xrsrke / pipegoose

Large scale 4D parallelism pre-training for 🤗 transformers in Mixture of Experts *(still work in progress)*
MIT License

Kernel Fusion using torch.jit #10

Open xrsrke opened 9 months ago

xrsrke commented 9 months ago

Fuse some commonly used functions, and automatically replace modules in an existing 🤗 transformers model with their corresponding fused modules.
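To make "fusing a function" concrete, here is a minimal sketch of one pattern that transformer training code commonly fuses: bias add, then dropout, then residual add. It assumes torch.jit.script is the fusion mechanism, as the issue title suggests. `bias_dropout_add` is an illustrative name, not an existing pipegoose function, and whether the ops actually end up in a single kernel depends on the device and the active TorchScript fuser.

```python
import torch
import torch.nn.functional as F


@torch.jit.script
def bias_dropout_add(
    x: torch.Tensor, bias: torch.Tensor, residual: torch.Tensor, p: float, training: bool
) -> torch.Tensor:
    # Bias add -> dropout -> residual add. In eager mode each op launches its
    # own kernel; under TorchScript the GPU fuser can merge these elementwise ops.
    out = F.dropout(x + bias, p=p, training=training)
    return residual + out


x = torch.randn(4, 8)
bias = torch.randn(8)
residual = torch.randn(4, 8)

# Inference: dropout is a no-op, so the result is just x + bias + residual.
eval_out = bias_dropout_add(x, bias, residual, 0.1, False)
print(torch.allclose(eval_out, x + bias + residual))  # True
```

In training mode each element of `out` is either zeroed or scaled by `1 / (1 - p)`, so the fused function keeps the exact semantics of the three eager ops.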

APIs

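from pipegoose.nn import TensorParallel, fusion

# `model` is an existing 🤗 transformers model and `parallel_context` an
# initialized pipegoose ParallelContext; fusion can be combined with
# tensor parallelism or any other parallelism
model = TensorParallel(model, parallel_context).parallelize()

# proposed API: fuse every supported module
model.fuse()

# or selective kernel fusion: fuse only the listed module types
model.fuse([fusion.LayerNorm, fusion.Attention])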
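The selective form implies a mapping from module types to fused replacements. Below is a minimal sketch of how such a swap could work on a plain PyTorch model, assuming fusion is done by walking the module tree and replacing children by type. `fuse_modules`, `FusedGELU` and `gelu_tanh` are hypothetical names for illustration, not pipegoose's implementation of `model.fuse()`.

```python
import torch
from torch import nn


@torch.jit.script
def gelu_tanh(x: torch.Tensor) -> torch.Tensor:
    # tanh approximation of GELU written as elementwise ops TorchScript can fuse
    return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * (x + 0.044715 * x * x * x)))


class FusedGELU(nn.Module):
    """Hypothetical fused stand-in for nn.GELU(approximate="tanh")."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return gelu_tanh(x)


def fuse_modules(model: nn.Module, mapping: dict) -> nn.Module:
    # Recursively swap each child whose exact type is a key in `mapping`
    # for a new instance of the fused class, modifying `model` in place.
    for name, child in model.named_children():
        fused_cls = mapping.get(type(child))
        if fused_cls is not None:
            setattr(model, name, fused_cls())
        else:
            fuse_modules(child, mapping)
    return model


model = nn.Sequential(nn.Linear(8, 16), nn.GELU(approximate="tanh"), nn.Linear(16, 8))
x = torch.randn(2, 8)
reference = model(x)
fuse_modules(model, {nn.GELU: FusedGELU})
print(type(model[1]).__name__)  # FusedGELU
print(torch.allclose(model(x), reference, atol=1e-5))  # True
```

This sketch constructs the replacement from scratch, which only works for stateless modules whose configuration matches. A fused LayerNorm or attention module would also need to copy the original weights and settings (e.g. `eps`, number of heads) from the module it replaces.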

TODOs

Reading (could be ignored)

sami-bg commented 8 months ago

Can you assign this to me? I'd like to give this a shot.

sami-bg commented 8 months ago

https://github.com/xrsrke/pipegoose/pull/36

xrsrke commented 8 months ago

@sami-bg Check out torch.fx. We could use it to detect the modules in a transformers model that can be fused and replace them with their fused versions, e.g. model.transformers.blocks[0].dropout = fused_dropout

But we shouldn't do this manually. See this tutorial on how to use torch.fx for fusion: https://pytorch.org/tutorials/intermediate/fx_conv_bn_fuser.html
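Following the approach of the linked tutorial (symbolically trace the model, pattern-match nodes in the graph, swap in a fused module, recompile), here is a minimal sketch with torch.fx. It assumes the target pattern is a `Linear` followed by a tanh-approximate `GELU`, as in a typical transformer MLP block; `bias_gelu`, `FusedLinearGELU` and `fuse_linear_gelu` are hypothetical names, not pipegoose code. Many 🤗 transformers models cannot be traced with plain `fx.symbolic_trace` because of data-dependent control flow; transformers ships its own tracer (`transformers.utils.fx.symbolic_trace`) for that case.

```python
import torch
import torch.fx as fx
import torch.nn.functional as F
from torch import nn


@torch.jit.script
def bias_gelu(x: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
    # Bias add + tanh-approximate GELU as one scripted, fusable function.
    y = x + bias
    return 0.5 * y * (1.0 + torch.tanh(0.7978845608028654 * (y + 0.044715 * y * y * y)))


class FusedLinearGELU(nn.Module):
    """Hypothetical fused module: bias-free matmul, then fused bias + GELU."""

    def __init__(self, linear: nn.Linear):
        super().__init__()
        self.weight = linear.weight  # shares the original parameters
        self.bias = linear.bias

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return bias_gelu(F.linear(x, self.weight), self.bias)


def fuse_linear_gelu(model: nn.Module) -> fx.GraphModule:
    gm = fx.symbolic_trace(model)
    modules = dict(gm.named_modules())
    for node in list(gm.graph.nodes):
        # Pattern: a GELU module call whose input is a Linear module call.
        if node.op != "call_module" or not isinstance(modules[node.target], nn.GELU):
            continue
        prev = node.args[0]
        if not (isinstance(prev, fx.Node) and prev.op == "call_module"):
            continue
        linear = modules[prev.target]
        if not isinstance(linear, nn.Linear) or linear.bias is None:
            continue
        # Only fuse when the Linear output feeds nothing but this GELU and the
        # GELU uses the same tanh approximation as bias_gelu.
        if len(prev.users) != 1 or modules[node.target].approximate != "tanh":
            continue
        fused_name = prev.target.replace(".", "_") + "_gelu_fused"
        gm.add_submodule(fused_name, FusedLinearGELU(linear))
        with gm.graph.inserting_after(prev):
            fused = gm.graph.call_module(fused_name, args=prev.args)
        node.replace_all_uses_with(fused)
        gm.graph.erase_node(node)  # GELU first: it is the only user of prev
        gm.graph.erase_node(prev)
    gm.graph.lint()
    gm.recompile()
    return gm


class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(8, 16)
        self.act = nn.GELU(approximate="tanh")
        self.fc2 = nn.Linear(16, 8)

    def forward(self, x):
        return self.fc2(self.act(self.fc1(x)))


model = MLP()
fused = fuse_linear_gelu(model)
x = torch.randn(4, 8)
print(torch.allclose(model(x), fused(x), atol=1e-5))
print(fused.code)  # forward now calls fc1_gelu_fused, then fc2
```

The pattern check is deliberately strict: if the `Linear` output has other users, removing it would break them, so the sketch leaves that pair untouched.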