pytorch / torchtune

PyTorch native finetuning library
https://pytorch.org/torchtune/main/
BSD 3-Clause "New" or "Revised" License

Could it support Gemma? #616

Closed solitude-alive closed 7 months ago

solitude-alive commented 7 months ago

Google's Gemma release includes a 2B model; it seems we could do full-parameter fine-tuning with fewer than 4 x 24GB GPUs. Do you plan to support it?

joecummings commented 7 months ago

We are considering many new model additions and will keep you posted!

kartikayk commented 7 months ago

@solitude-alive would you be open to adding this model? I'm happy to help share specific pointers and review code if you're interested. We'd love the contribution.

solitude-alive commented 7 months ago

@kartikayk Yeah, I'm happy to do that. I would try it.

kartikayk commented 7 months ago

@solitude-alive awesome!

For a starting point, take a look at the Mistral 7B model builder in: https://github.com/pytorch/torchtune/blob/main/torchtune/models/mistral/_model_builders.py

We expose specific models through model builders, which basically stitch together components (e.g. Attention, RoPE, RMSNorm, etc.). You can find some examples here: https://github.com/pytorch/torchtune/blob/main/torchtune/models/mistral/_component_builders.py#L36

I think adding support for gemma_2b would be similar. You just need to make sure the components line up with what Gemma is doing.
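As a rough illustration, a Gemma 2B model builder could look something like the sketch below. This is only a sketch: the gemma component builder doesn't exist yet (it would mirror the mistral one linked above), and the hyperparameters are taken from the public Gemma 2B config and should be double-checked against the released checkpoint.

# Hypothetical sketch of torchtune/models/gemma/_model_builders.py, mirroring
# the Mistral builder. The `gemma` component builder and the exact
# hyperparameter values are assumptions to be verified.
from torchtune.models.gemma._component_builders import gemma
from torchtune.modules import TransformerDecoder

def gemma_2b() -> TransformerDecoder:
    """Builder for a Gemma 2B model initialized with default parameter values."""
    return gemma(
        vocab_size=256_000,
        num_layers=18,
        num_heads=8,
        num_kv_heads=1,      # multi-query attention
        embed_dim=2048,
        intermediate_dim=16_384,
        max_seq_len=8192,
        attn_dropout=0.0,
        norm_eps=1e-6,
    )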

solitude-alive commented 7 months ago

@kartikayk Hi,

_model_builders.py and _component_builders.py have been mostly completed, except for some components that need to be confirmed.

Is there documentation on how to load the weights files? It seems that Gemma is only distributed as [model-00001-of-00002.safetensors, model-00002-of-00002.safetensors] rather than .bin or .pth files.

joecummings commented 7 months ago

@solitude-alive great catch! Right now, TorchTune supports only PyTorch-native .bin or .pt formats.

In order to add Gemma, we need to think about adding functionality to support loading safetensors. Hugging Face has a great library and resources for this here: https://huggingface.co/docs/safetensors/index#usage and it probably makes sense to take a look at how we incorporate loading in TorchTune here: https://github.com/pytorch/torchtune/blob/6d9368fbba95a3753dc30627ba91b6d8f21dffe5/torchtune/utils/_checkpointing/_checkpointer_utils.py#L50.

Is this something you feel comfortable adding? This would be an incredible feature b/c there are a lot of other models on the HF Hub that are only distributed as safetensors, too.
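For reference, reading a sharded safetensors checkpoint with the Hugging Face library is only a few lines. A minimal sketch (the shard names below are the Gemma file names from above, used for illustration):

# Minimal sketch: merge sharded safetensors files into a single state_dict
# using the Hugging Face safetensors library.
from safetensors.torch import load_file

shards = [
    "model-00001-of-00002.safetensors",
    "model-00002-of-00002.safetensors",
]

state_dict = {}
for shard in shards:
    # load_file returns a plain dict mapping tensor names to torch.Tensor
    state_dict.update(load_file(shard, device="cpu"))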

joecummings commented 7 months ago

Also, @solitude-alive - would love for you to join the Discord channel (see our README for invite link) so we can quickly answer any questions you may have as you work on this!

solitude-alive commented 7 months ago

@joecummings Yeah, thanks.

kartikayk commented 7 months ago

@solitude-alive Awesome! As @joecummings said, it would be great to add safetensors support to TorchTune's FullModelHFCheckpointer.

I verified that safetensors.safe_open produces the same state_dict from safetensor files as the TorchTune HF Checkpointer does from bin files for the Llama-2-13B model. Here's a minimal validation:

 

# Examine safetensors

import torch

from safetensors import safe_open
from torchtune.models import convert_weights
from torchtune.utils import FullModelHFCheckpointer, ModelType

checkpoint_dir = '/data/users/kartikayk/cpts/Llama-2-13b-hf/'
safetensor_files = ['model-00001-of-00003.safetensors', 'model-00002-of-00003.safetensors', 'model-00003-of-00003.safetensors']
pytorch_files = ['pytorch_model-00001-of-00003.bin', 'pytorch_model-00002-of-00003.bin', 'pytorch_model-00003-of-00003.bin']

safetensor_sd = {}

for file in safetensor_files:
    file_path = checkpoint_dir + file
    with safe_open(file_path, framework="pt", device="cpu") as f:
        for key in f.keys():
            safetensor_sd[key] = f.get_tensor(key)

# convert the state_dict from HF format to TorchTune format
# hf_to_tune needs to know some params for correct conversion
safetensor_sd_torchtune = convert_weights.hf_to_tune(safetensor_sd, num_heads=40, num_kv_heads=40, dim=5120)

# Use TorchTune's HF Checkpointer to get the state_dict
checkpointer = FullModelHFCheckpointer(
    checkpoint_dir=checkpoint_dir,
    checkpoint_files=pytorch_files,
    output_dir='/data/users/kartikayk/cpts/Llama-2-13b-hf/',
    model_type=ModelType.LLAMA2
)

torchtune_sd = checkpointer.load_checkpoint()

# assert that we get the same keys and values
# torchtune checkpointer adds an additional 'model' key
for key in torchtune_sd['model'].keys():
    assert torch.equal(torchtune_sd['model'][key], safetensor_sd_torchtune[key])

And here's the output:

[image: script output showing the assertions pass for all keys]

Given that these are numerically equivalent, I think the best way forward would be for you to add a flag to FullModelHFCheckpointer - something like is_safetensor - and when this is True, just use safetensors.safe_open instead of safe_torch_load to get the state_dict. Everything else, including the conversion to TorchTune's format, should stay the same. This is the relevant function: https://github.com/pytorch/torchtune/blob/main/torchtune/utils/_checkpointing/_checkpointer.py#L323
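Roughly, the loading logic could branch like the sketch below. This is only a sketch, not the checkpointer's actual code: the helper name, the is_safetensor argument, and the import path for safe_torch_load (taken from the file linked earlier) are illustrative.

# Rough sketch of the proposed branch when merging sharded HF checkpoint files.
from safetensors import safe_open
from torchtune.utils._checkpointing._checkpointer_utils import safe_torch_load

def _load_hf_state_dict(checkpoint_paths, is_safetensor):
    """Merge sharded HF checkpoint files into one state_dict (hypothetical helper)."""
    merged = {}
    for path in checkpoint_paths:
        if is_safetensor:
            # new path: read tensors out of a .safetensors shard
            with safe_open(path, framework="pt", device="cpu") as f:
                shard = {key: f.get_tensor(key) for key in f.keys()}
        else:
            # existing behavior for PyTorch-native .bin / .pt files
            shard = safe_torch_load(path)
        merged.update(shard)
    return merged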

Does this make sense to you?

solitude-alive commented 7 months ago

@kartikayk Thank you for your suggestion.

solitude-alive commented 7 months ago

Hi, it seems there are errors on my device when I set output_pro.weight = Tok_embedding.weight for Gemma. Is there any way to fix it?

[rank1]:[E ProcessGroupNCCL.cpp:523] [Rank 1] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=21, OpType=BROADCAST, NumelIn=524290048, NumelOut=524290048, Timeout(ms)=600000) ran for 600431 milliseconds before timing out.
kartikayk commented 7 months ago

Seems like this is actively being discussed on the discord. Once the discussion is over, we can come back and summarize it here.

cc: @ebsmothers

solitude-alive commented 7 months ago

Hi, it seems there are errors on my device when I set output_pro.weight = Tok_embedding.weight for Gemma. Is there any way to fix it?

[rank1]:[E ProcessGroupNCCL.cpp:523] [Rank 1] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=21, OpType=BROADCAST, NumelIn=524290048, NumelOut=524290048, Timeout(ms)=600000) ran for 600431 milliseconds before timing out.

Thanks for the discussion, there is a temporary solution: remove any weight tying that occurs before FSDP wrapping and do the tying here instead. https://github.com/pytorch/torchtune/blob/73647e26eb327c1e7dff6d6d12e4060c16c11da9/recipes/full_finetune_distributed.py#L261

ebsmothers commented 7 months ago

Yeah, to summarize the discussion on Discord: when training with FSDP, the way we initialize the model undoes the weight tying. Specifically, I suspect it's because we initialize on the meta device. Not only that, but we cannot tie weights prior to FSDP wrapping, or else we will hit a hang at our first sync point. You can see e.g. here for some discussion on the topic.

We can get around this by instead tying weights after FSDP wrapping. I believe @solitude-alive already has a tie_weight utility defined on their fork; we just need to call this in the recipe instead of the model builder. This way we can control when it gets executed: we can execute it anytime in our single-device recipes, but need to execute it after FSDP wrapping in our distributed recipes. (Another option would be some kind of post-init hook, but I'm not sure offhand how to implement it.)
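Concretely, the ordering in the distributed recipe would look roughly like the fragment below. This is only a sketch: tie_weights stands in for the utility on @solitude-alive's fork, the attribute names assume torchtune's TransformerDecoder (tok_embeddings / output), and the FSDP wrapping call is abbreviated relative to the actual recipe.

# Sketch of the ordering in recipes/full_finetune_distributed.py (not the exact recipe code).
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

def tie_weights(model):
    # stand-in for the utility on the fork: point the output projection
    # at the token embedding weight (attribute names assumed)
    model.output.weight = model.tok_embeddings.weight

model = gemma_2b()    # the model builder itself no longer ties weights
model = FSDP(model)   # wrap first; tying before this point hangs at the first NCCL sync
tie_weights(model)    # tie only after wrapping (in single-device recipes, timing doesn't matter)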