pytorch / torchtitan

A native PyTorch Library for large model training
BSD 3-Clause "New" or "Revised" License

Support Gemma2 in torchtitan #594

Open pansershrek opened 1 week ago

pansershrek commented 1 week ago

Are there any plans to support Gemma2 in torchtitan? I tried to use torchtitan to finetune a Gemma2 model, but got stuck on the following problem: how do you parallelize the tied layer in the Gemma2 model? Maybe somebody knows the solution to this problem 😄

awgu commented 1 week ago

If you apply fully_shard to each transformer block and then to the root module, this should work for the tied embedding and final linear: the root module will manage both.
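
A minimal sketch of that wrapping order, assuming an FSDP2-style fully_shard API and a model whose transformer blocks live in model.layers (both assumptions, not taken from this thread):

```python
# Sketch only: the import path varies by PyTorch version
# (fully_shard may live under torch.distributed._composable.fsdp).
from torch.distributed.fsdp import fully_shard

def apply_fully_shard(model, mesh):
    # Wrap each transformer block as its own FSDP unit.
    for block in model.layers:
        fully_shard(block, mesh=mesh)
    # Wrap the root module last. Parameters not owned by an inner unit,
    # including the weight shared by the embedding and the final linear,
    # fall to the root, so the tied weight is managed by a single unit.
    fully_shard(model, mesh=mesh)
    return model
```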

pansershrek commented 1 week ago

I want to shard the output embedding layer. I use the same strategy as in Llama, but training got stuck after the first batch:

ColwiseParallel(
    input_layouts=Shard(1),
    output_layouts=Shard(-1) if loss_parallel else Replicate(),
    use_local_output=not loss_parallel,
)
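
For context, a sketch of how such a plan is typically applied via parallelize_module; the module name "output" for the final projection and the exact import paths are assumptions based on the Llama setup, not on Gemma2:

```python
# Sketch under assumptions: the final projection is registered as "output",
# and Shard/Replicate come from torch.distributed.tensor (path varies by version).
from torch.distributed.tensor import Replicate, Shard
from torch.distributed.tensor.parallel import ColwiseParallel, parallelize_module

def apply_tp_to_output(model, tp_mesh, loss_parallel: bool):
    parallelize_module(
        model,
        tp_mesh,
        {
            # Column-wise shard the vocab projection; keep the output sharded
            # on the class dim when loss parallelism is enabled.
            "output": ColwiseParallel(
                input_layouts=Shard(1),
                output_layouts=Shard(-1) if loss_parallel else Replicate(),
                use_local_output=not loss_parallel,
            ),
        },
    )
    return model
```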

awgu commented 1 week ago

Do you want to train with 2D parallelism (FSDP + TP), or with TP only?