alpa-projects / alpa

Training and serving large-scale neural networks with auto parallelization.
https://alpa.ai
Apache License 2.0

Support PyTorch XLA as a frontend #945

Open · Kelvin-Ng opened this issue 1 year ago

Kelvin-Ng commented 1 year ago


Describe the new feature and the current behavior/state

Most models in the transformers library, like LLaMA, do not have a Flax version. Alpa uses torch.fx.symbolic_trace to convert such models to graphs, but these models contain dynamic behavior (e.g. data-dependent control flow), so symbolic tracing fails.
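For illustration, here is a minimal toy module (not LLaMA itself) with the kind of data-dependent control flow that symbolic_trace rejects:

```python
import torch
from torch import fx

class DynamicBranch(torch.nn.Module):
    def forward(self, x):
        # The branch taken depends on the *values* in x, which
        # symbolic tracing cannot resolve at trace time.
        if x.sum() > 0:
            return x * 2
        return x - 1

fx.symbolic_trace(DynamicBranch())
# torch.fx.proxy.TraceError: symbolically traced variables cannot be
# used as inputs to control flow
```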

Although the models are dynamic, PyTorch XLA is actually able to trace and run them. For example, here is the XLA IR for the LLaMA model from transformers: https://gist.github.com/Kelvin-Ng/6fa6ededead2a42a806a69fd4c932a3e
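A sketch of how such an IR dump can be obtained, using torch_xla's internal debug helpers (internal API, so the exact entry points may vary across versions; the Linear module is just a stand-in for the real model):

```python
import torch
import torch_xla
import torch_xla.core.xla_model as xm

device = xm.xla_device()
model = torch.nn.Linear(8, 8).to(device)  # stand-in for the real model
out = model(torch.randn(2, 8, device=device))

# Dump the lazy-tensor IR torch_xla has recorded for `out`,
# and the corresponding lowered HLO:
print(torch_xla._XLAC._get_xla_tensors_text([out]))
print(torch_xla._XLAC._get_xla_tensors_hlo([out]))
```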

Supporting PyTorch XLA as a frontend will enable the use of many models from the transformers library.


Describe alternatives you've considered

1. Provide an API that allows manually supplying an XLA IR, so that one can use PyTorch XLA to convert a model to XLA IR and then feed that IR to Alpa.

2. Apparently torch.jit.trace also works on these models. Is it possible for Alpa to accept the output of torch.jit.trace instead of torch.fx.symbolic_trace? (See the sketch after this list.)
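A quick sketch of why torch.jit.trace gets further than symbolic_trace: it records the ops actually executed for a concrete example input rather than tracing symbolically, so the data-dependent branch is baked into the trace (with a TracerWarning) instead of raising an error:

```python
import torch

class DynamicBranch(torch.nn.Module):  # same toy module as above
    def forward(self, x):
        if x.sum() > 0:
            return x * 2
        return x - 1

traced = torch.jit.trace(DynamicBranch(), torch.randn(3, 3))
print(traced.graph)  # TorchScript IR, not an fx.GraphModule
```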


richardliaw commented 1 year ago

Hey @Kelvin-Ng, I think this makes a lot of sense. Does Alpa primarily need to accept the output of torch.jit.trace instead?

Kelvin-Ng commented 1 year ago

In fact, I was proposing two alternatives: either use PyTorch XLA, or use torch.jit.trace.

I think using PyTorch XLA is the better option because it is the most general: anything that PyTorch can run would be supported. However, it may require more modifications to Alpa, because I see that some of the pipeline-parallelism code operates on JAX graphs (jaxprs), which we would not have for PyTorch XLA.
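To make that gap concrete, this is the kind of JAX-level graph I mean; a minimal sketch, with `layer` as a stand-in function:

```python
import jax
import jax.numpy as jnp

def layer(x):
    return jnp.tanh(x @ x.T)

# Alpa's pipeline-parallelism passes inspect jaxprs like the one printed
# here; an HLO module produced by torch_xla has no corresponding jaxpr.
print(jax.make_jaxpr(layer)(jnp.ones((4, 4))))
```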

Or we can use torch.jit.trace (instead of the symbolic_trace that Alpa currently uses). I guess we only need to modify the conversion code (here, I suppose: https://github.com/alpa-projects/alpa/blob/main/alpa/torch/nn/__init__.py#L22). It currently performs the conversion on the code generated by symbolic_trace, and we would need to change it to handle the output of torch.jit.trace. However, I do not fully understand this piece of code, so I am not sure how to do that reliably, especially since the graphs produced by torch.jit.trace are more complicated than those from symbolic_trace, according to the PyTorch documentation. The benefit of this approach is that no modification to the Alpa core code is necessary.
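For what it's worth, here is a hypothetical sketch (none of this is existing Alpa API) of what a torch.jit.trace-based entry point might look like; the point is that the converter would walk TorchScript nodes instead of fx nodes:

```python
import torch

def convert_jit_trace(module, example_inputs):
    """Hypothetical sketch, not Alpa's actual API: a torch.jit.trace-based
    entry point in place of the current fx-based one."""
    traced = torch.jit.trace(module, example_inputs)
    # traced.graph is TorchScript IR rather than an fx.GraphModule, so a
    # converter would iterate over TorchScript nodes and map each ATen op
    # to its JAX/XLA counterpart here:
    for node in traced.graph.nodes():
        print(node.kind())  # e.g. "aten::linear", "prim::Constant"
    return traced

convert_jit_trace(torch.nn.Linear(4, 4), torch.randn(1, 4))
```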

gjoliver commented 1 year ago

We should go the TorchXLA route if possible. I think the current torch.jit implementation was largely a one-off effort and will be hard to maintain / not as scalable.