microsoft / onnxscript

ONNX Script enables developers to naturally author ONNX functions and models using a subset of Python.
https://onnxscript.ai/
MIT License

Support for collectives / MPI primitives #1309

Open vincent-mayer opened 2 months ago

vincent-mayer commented 2 months ago

Hi,

Great work on onnxscript! I was wondering whether collective operations / MPI primitives such as reduce_scatter and all_gather will be added at some point down the road. If not, I'd be very curious about the reasoning.

Thanks!
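For reference, the semantics of the two primitives named above can be simulated in a single process. This is a minimal NumPy sketch, assuming equally sized per-rank tensors and a sum reduction. The helpers `all_gather` and `reduce_scatter` are hypothetical illustrations, not onnxscript or MPI APIs, and no real communication takes place.

```python
from __future__ import annotations

import numpy as np


def all_gather(shards: list[np.ndarray], axis: int = 0) -> list[np.ndarray]:
    """Simulated all_gather: every rank ends up with the concatenation of all shards."""
    gathered = np.concatenate(shards, axis=axis)
    # One copy per rank, mimicking each rank holding its own result buffer.
    return [gathered.copy() for _ in shards]


def reduce_scatter(inputs: list[np.ndarray], axis: int = 0) -> list[np.ndarray]:
    """Simulated reduce_scatter: sum the full tensors across ranks, then give rank i the i-th chunk."""
    world_size = len(inputs)
    total = np.sum(np.stack(inputs), axis=0)  # element-wise sum over ranks
    # np.split requires the axis length to be divisible by world_size.
    return np.split(total, world_size, axis=axis)


# Two simulated ranks: rank 0 holds [0, 1, 2, 3], rank 1 holds [0, 2, 4, 6].
ranks = [np.arange(4.0) * (r + 1) for r in range(2)]
scattered = reduce_scatter(ranks)   # rank 0 -> [0, 3], rank 1 -> [6, 9]
gathered = all_gather(scattered)    # every rank -> [0, 3, 6, 9]
```

Chaining the two reproduces an all_reduce (every rank sees the full sum), which is why these primitives tend to show up together in sharded training and inference.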

justinchuby commented 2 months ago

Collective ops are being discussed in ONNX IIRC. Could you explain how you plan to use them or how they should appear in the graph? Thanks!

vincent-mayer commented 2 months ago

Thanks for the response @justinchuby! I saw this ONNX Runtime issue https://github.com/microsoft/onnxruntime/issues/8244 and this talk from an NVIDIA engineer, but not much seems to have happened since to address the issues raised in the talk.

StableHLO, another portability layer similar to ONNX, has added these operations (see here). Essentially, we would like to express model sharding, and where collective operations must be inserted into the graph, at the Python level, then export those annotations/ops via the ONNX IR so that a compiler can consume them later. For an example of usage in Python, see TensorRT-LLM here.

justinchuby commented 2 months ago

cc @gramalingam if you can share more info