We are considering many new model additions and will keep you posted!
@solitude-alive would you be open to adding this model? I'm happy to help share specific pointers and review code if you're interested. We'd love the contribution.
@kartikayk Yeah, I'm happy to do that. I'll give it a try.
@solitude-alive awesome!
For a starting point, take a look at the Mistral 7B model builder in: https://github.com/pytorch/torchtune/blob/main/torchtune/models/mistral/_model_builders.py
We expose specific models through model builders, which basically stitch together components (e.g. Attention, RoPE, RMS Norm, etc.). You can find some examples here: https://github.com/pytorch/torchtune/blob/main/torchtune/models/mistral/_component_builders.py#L36
I think adding support for gemma_2b would be similar. You just need to make sure the components line up with what Gemma is doing.
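For context, here's a rough sketch of what a gemma_2b model builder could look like, following the mistral_7b pattern above. The gemma component builder, its module path, and the exact argument names are assumptions, and the config values should be double-checked against the published Gemma 2B config:

# Sketch only: a hypothetical torchtune/models/gemma/_model_builders.py
from torchtune.modules import TransformerDecoder

# Assumed component builder, analogous to torchtune.models.mistral.mistral
from torchtune.models.gemma._component_builders import gemma


def gemma_2b() -> TransformerDecoder:
    """Builder for a Gemma 2B model (values from the published Gemma 2B config)."""
    return gemma(
        vocab_size=256_000,
        num_layers=18,
        num_heads=8,
        num_kv_heads=1,          # Gemma 2B uses multi-query attention
        head_dim=256,
        embed_dim=2048,
        intermediate_dim=16_384,
        max_seq_len=8192,
        norm_eps=1e-6,
    )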
@kartikayk Hi,
_model_builders.py and _component_builders.py are mostly complete, except for some components that still need to be confirmed.
Is there documentation on how to load the weight files? It seems that Gemma only provides [model-00001-of-00002.safetensors, model-00002-of-00002.safetensors] rather than .bin or .pth files.
@solitude-alive great catch! Right now, TorchTune supports only PyTorch-native .bin or .pt formats.
In order to add Gemma, we need to add functionality for loading safetensors. Hugging Face has a great library and resources for this here: https://huggingface.co/docs/safetensors/index#usage and it probably makes sense to take a look at how we incorporate loading in TorchTune here: https://github.com/pytorch/torchtune/blob/6d9368fbba95a3753dc30627ba91b6d8f21dffe5/torchtune/utils/_checkpointing/_checkpointer_utils.py#L50.
Is this something you feel comfortable adding? This would be an incredible feature because there are a lot of other models on the HF Hub that only ship safetensors, too.
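For reference, here's a minimal sketch of reading one safetensors shard into a plain {name: tensor} state dict with the safetensors library (the file name is illustrative):

from safetensors.torch import load_file

# Each shard loads directly into a dict of tensors on the requested device.
state_dict = load_file("model-00001-of-00002.safetensors", device="cpu")
print(f"loaded {len(state_dict)} tensors, e.g. {next(iter(state_dict))}")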
Also, @solitude-alive - would love for you to join the Discord channel (see our README for invite link) so we can quickly answer any questions you may have as you work on this!
@joecummings Yeah, thanks.
@solitude-alive Awesome! As @joecummings said, it would be great to add safetensors support to TorchTune's FullModelHFCheckpointer.
I verified that safetensors.safe_open produces the same state_dict from safetensors files as the TorchTune HF Checkpointer does from .bin files for the Llama-2-13B model. Here's a minimal validation:
# Examine safetensors
import torch

from safetensors import safe_open
from torchtune.models import convert_weights
from torchtune.utils import FullModelHFCheckpointer, ModelType

checkpoint_dir = '/data/users/kartikayk/cpts/Llama-2-13b-hf/'
safetensor_files = ['model-00001-of-00003.safetensors', 'model-00002-of-00003.safetensors', 'model-00003-of-00003.safetensors']
pytorch_files = ['pytorch_model-00001-of-00003.bin', 'pytorch_model-00002-of-00003.bin', 'pytorch_model-00003-of-00003.bin']

# Load every safetensors shard into a single state_dict
safetensor_sd = {}
for file in safetensor_files:
    file_path = checkpoint_dir + file
    with safe_open(file_path, framework="pt", device="cpu") as f:
        for key in f.keys():
            safetensor_sd[key] = f.get_tensor(key)

# Convert the state_dict from HF format to TorchTune format;
# hf_to_tune needs to know some params for correct conversion
safetensor_sd_torchtune = convert_weights.hf_to_tune(safetensor_sd, num_heads=40, num_kv_heads=40, dim=5120)

# Use TorchTune's HF Checkpointer to get the state_dict from the .bin files
checkpointer = FullModelHFCheckpointer(
    checkpoint_dir=checkpoint_dir,
    checkpoint_files=pytorch_files,
    output_dir='/data/users/kartikayk/cpts/Llama-2-13b-hf/',
    model_type=ModelType.LLAMA2
)
torchtune_sd = checkpointer.load_checkpoint()

# Assert that we get the same keys and values
# (the TorchTune checkpointer nests the weights under an additional 'model' key)
for key in torchtune_sd['model'].keys():
    assert torch.equal(torchtune_sd['model'][key], safetensor_sd_torchtune[key])
And here's the output:
Given that these are numerically equivalent, I think the best way forward would be for you to add a flag to FullModelHFCheckpointer - something like is_safetensor - and, when it is True, use safetensors loading (safe_open, as in the snippet above) instead of safe_torch_load to get the state_dict. Everything else, including the conversion to TorchTune's format, should stay the same. This is the relevant function: https://github.com/pytorch/torchtune/blob/main/torchtune/utils/_checkpointing/_checkpointer.py#L323
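To make this concrete, here's a rough sketch of how the load path could branch on such a flag. is_safetensor, the helper name, and the surrounding structure are illustrative rather than the actual FullModelHFCheckpointer code, and the safe_torch_load import path is taken from the checkpointer utils file linked earlier:

from safetensors.torch import load_file as load_safetensors_file
from torchtune.utils._checkpointing._checkpointer_utils import safe_torch_load


def load_hf_shard(path: str, is_safetensor: bool = False) -> dict:
    # Hypothetical helper: pick the loader based on the file format. Either way
    # we get a {param_name: tensor} dict that then goes through hf_to_tune as usual.
    if is_safetensor:
        return load_safetensors_file(path, device="cpu")
    return safe_torch_load(path)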
Does this make sense to you?
@kartikayk Thank you for your suggestion.
Hi, I'm hitting errors on my device when I set output_pro.weight = Tok_embedding.weight (tying the output projection to the token embedding) for Gemma. Is there any way to fix it?
[rank1]:[E ProcessGroupNCCL.cpp:523] [Rank 1] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=21, OpType=BROADCAST, NumelIn=524290048, NumelOut=524290048, Timeout(ms)=600000) ran for 600431 milliseconds before timing out.
Seems like this is actively being discussed on the discord. Once the discussion is over, we can come back and summarize it here.
cc: @ebsmothers
Thanks for the discussion. There is a temporary solution: remove any weight tying that occurs before FSDP wrapping and apply the tying here instead: https://github.com/pytorch/torchtune/blob/73647e26eb327c1e7dff6d6d12e4060c16c11da9/recipes/full_finetune_distributed.py#L261
Yeah, to summarize the discussion on Discord: when training with FSDP, the way we initialize the model undoes the weight tying. Specifically, I suspect it's because we initialize on the meta device. Not only that, but we cannot tie weights prior to FSDP wrapping or else we will hit a hang at our first sync point. You can see e.g. here for some discussion on the topic.
We can get around this by instead tying weights after FSDP wrapping. I believe @solitude-alive already has a tie_weight utility defined on their fork; we just need to call it in the recipe instead of the model builder. This way we can control when it gets executed: we can run it at any time in our single-device recipes, but need to run it after FSDP wrapping in our distributed recipes. (Another option would be some kind of post-init hook, but I'm not sure offhand how to implement that.)
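A minimal sketch of the idea, assuming the TransformerDecoder exposes tok_embeddings and output modules (the attribute names and the helper below are illustrative stand-ins for the tie_weight utility on the fork):

from torch import nn


def tie_weights(model: nn.Module) -> None:
    # Make the output projection and the token embedding share a single
    # parameter, as Gemma requires.
    model.output.weight = model.tok_embeddings.weight

# Single-device recipes: call tie_weights(model) any time after the model is built.
# Distributed recipes: call tie_weights(model) only AFTER FSDP wrapping; tying
# beforehand is what leads to the hang at the first collective shown above.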
Google also released a 2B model, and it seems we could do a full-parameter fine-tune of it with fewer than 4*24GB GPUs. Do you plan to support it?