Open pansershrek opened 1 week ago
If you apply fully_shard to each transformer block and then to the root module, this should work for the tied embedding and final linear layers: the root module will manage both.
I want to shard the output embedding layer. I use the same strategy as in Llama, but training gets stuck after the first batch:
```python
ColwiseParallel(
    input_layouts=Shard(1),
    output_layouts=Shard(-1) if loss_parallel else Replicate(),
    use_local_output=not loss_parallel,
)
```
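For context, a hedged sketch of how such a plan is applied with `parallelize_module`. The module name `"output"` and the `loss_parallel` flag follow torchtitan's Llama plan; the `apply_tp_to_output` helper and the `tp_mesh` argument are assumptions for illustration:

```python
# Sketch: assumes `model.output` is the final (tied) linear projection
# and `tp_mesh` is a 1-D DeviceMesh for the tensor-parallel group.
try:
    from torch.distributed.tensor import Replicate, Shard  # torch >= 2.4
except ImportError:
    from torch.distributed._tensor import Replicate, Shard
from torch.distributed.tensor.parallel import ColwiseParallel, parallelize_module


def apply_tp_to_output(model, tp_mesh, loss_parallel: bool):
    plan = {
        "output": ColwiseParallel(
            input_layouts=Shard(1),
            # Keep the logits sharded over the vocab dim when the loss is
            # computed in parallel; otherwise gather them back to a replica.
            output_layouts=Shard(-1) if loss_parallel else Replicate(),
            use_local_output=not loss_parallel,
        ),
    }
    return parallelize_module(model, tp_mesh, plan)
```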
Do you want to train with 2D parallelism (FSDP + TP)? With TP only?
Are there any plans to support Gemma2 in torchtitan? I tried to use torchtitan to finetune a Gemma2 model, but got stuck on the following problem: how do you parallelize the tied layer in the Gemma2 model? Maybe somebody knows the solution to this problem 😄