xrsrke / pipegoose

Large scale 4D parallelism pre-training for 🤗 transformers in Mixture of Experts *(still work in progress)*

Automatic module mapping using torch.fx #40

Open · xrsrke opened 9 months ago

xrsrke commented 9 months ago

Notes

APIs

```python
model = AutoModel.from_pretrained(...)
ParallelMapping(model).is_mlp(module)
```

Write a function that:

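Since the task description above is cut off, here is a minimal sketch of the kind of automatic mapping this issue asks for. It classifies a model's `nn.Linear` submodules by their names in `named_modules()`; the eventual goal is to derive the same information from a torch.fx trace of the model rather than from name heuristics. `AutoParallelMapping`, its method names, and the heuristics are illustrative assumptions, not pipegoose's actual API.

```python
import torch.nn as nn
from transformers import AutoModel


class AutoParallelMapping:
    """Hypothetical sketch of an automatically derived module mapping."""

    def __init__(self, model: nn.Module):
        self.model = model
        # module name -> module, for every submodule in the model
        self._name_to_module = dict(model.named_modules())

    def is_mlp(self, name: str) -> bool:
        # heuristic: HF MLP blocks live under a submodule called "mlp"
        module = self._name_to_module[name]
        return ".mlp." in f".{name}." and isinstance(module, nn.Linear)

    def is_attention(self, name: str) -> bool:
        module = self._name_to_module[name]
        return any(key in name for key in ("attention", "attn")) and isinstance(module, nn.Linear)


if __name__ == "__main__":
    # bloom-560m is the model discussed below in this thread
    model = AutoModel.from_pretrained("bigscience/bloom-560m")
    mapping = AutoParallelMapping(model)
    mlp_linears = [name for name in dict(model.named_modules()) if mapping.is_mlp(name)]
    print(mlp_linears[:4])
```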
giorgionicoli commented 9 months ago

I'm taking a stab at this issue. I'll first come up with a solution that works for at least bloom-560m, which is already mapped in ParallelMapping, so that we can check whether the automatic mapping works. Then we can work on making it more general.

yugen-ok commented 9 months ago

I'm working on this issue.

For testing, could I see some MWEs (minimal working examples) of what the inputs to these methods look like? I.e., typical inputs so I can make sure I'm getting the right results.

xrsrke commented 9 months ago

@yugen-ok

```python
from transformers import AutoModel

# assuming ParallelMapping is importable from pipegoose's tensor-parallel module
from pipegoose.nn.tensor_parallel.parallel_mapping import ParallelMapping

model = AutoModel.from_pretrained(...)
parallel_mapping = ParallelMapping(model)

# hand-written reference mapping for this model
ref_mapping = {...}

# build the automatic mapping and compare it against the reference
mapping = {}
for name, module in model.named_modules():
    mapping[name] = parallel_mapping.is_row_parallel(name, module)

assert ref_mapping == mapping
```
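
For a concrete sense of what `ref_mapping` could contain for bloom-560m (the model mentioned above), here is an illustrative fragment. The module names come from BloomModel's `named_modules()`; the True/False labels are an assumption based on the usual Megatron-style split (output projections row-parallel, input projections column-parallel), not pipegoose's canonical mapping.

```python
# Illustrative only: a few entries of a ref_mapping for bloom-560m when
# checking is_row_parallel. Every other module name returned by
# model.named_modules() (layernorms, dropouts, the blocks themselves, ...)
# would map to False.
ref_mapping = {
    "h.0.self_attention.query_key_value": False,  # column-parallel
    "h.0.self_attention.dense": True,             # row-parallel
    "h.0.mlp.dense_h_to_4h": False,               # column-parallel
    "h.0.mlp.dense_4h_to_h": True,                # row-parallel
    # ...and similarly for the remaining transformer blocks
}
```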