pytorch / xla

Enabling PyTorch on XLA Devices (e.g. Google TPU)
https://pytorch.org/xla

How to avoid compilation in a section of code? #7622

Open Jiayi-Pan opened 2 days ago

Jiayi-Pan commented 2 days ago

❓ Questions and Help

We are using PyTorch/XLA with TPUs to train multimodal language models.

We can keep most of the code, such as the image encoding and the forward pass of the LLM backbone, static-shaped, which XLA handles well. However, making the part that fuses the image and text embeddings into the input embeddings static is extremely challenging.

Currently, we use mark_step to isolate that section from the rest of the code, so that only it is recompiled each time. Although this part is computationally very light, the recompilation is extremely slow and often consumes the majority of training time.

Documentation on this issue is very hard to find, and we are exploring better solutions, such as running that part on the CPU, running it in eager mode, or not saving that part of the graph to avoid OOM errors during long training runs. Do you have any suggestions or pointers on how to work around this inefficiency?

The following pseudocode illustrates our problem:

import torch_xla.core.xla_model as xm

# loader, image_encoder, get_text_embedding, fuse_embedding and llm are our own code
for raw_image_tensor, text_token_idxs, sequence_info_dict in loader:  # loading data
  # these tensors have static shapes; XLA works great on them
  image_embeddings = image_encoder(raw_image_tensor)
  text_embeddings = get_text_embedding(text_token_idxs)

  xm.mark_step()
  # this part is very light in compute, but dynamic. We currently just recompile this graph every single time :(
  input_embeddings = fuse_embedding(image_embeddings, text_embeddings, sequence_info_dict)
  xm.mark_step()

  # these tensors have static shapes; XLA works great on them
  output_logits = llm(input_embeddings)
  # loss compute / backward / optimizer step omitted
JackCaoG commented 2 days ago

Great question. I have a couple of questions and a couple of suggestions.

Questions

  1. It seems that even though fuse_embedding is dynamic, the shape of input_embeddings is static? That would explain why the llm HLO is static.
  2. How dynamic is fuse_embedding? For example, are there about 100 possible shape combinations in total, or literally thousands?
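One way to answer question 2 empirically is to log the shape signature that reaches the dynamic region at every step and count the distinct values. A minimal plain-Python sketch: `ShapeLog` is a hypothetical helper, and it assumes the inputs are given as shapes (a `torch.Size` is a tuple, so `x.shape` works) rather than tensors.

```python
from collections import Counter


class ShapeLog:
    """Hypothetical helper: counts distinct input-shape signatures.

    Each distinct signature entering a traced region roughly corresponds
    to one graph that XLA has to compile.
    """

    def __init__(self):
        self.counts = Counter()

    def record(self, *shapes):
        # Normalise torch.Size / lists to plain tuples so they are hashable
        # and compare equal across steps.
        signature = tuple(tuple(s) for s in shapes)
        self.counts[signature] += 1
        return signature

    def num_distinct(self):
        return len(self.counts)


# Usage with made-up shapes; in a real loop: log.record(a.shape, b.shape)
log = ShapeLog()
log.record((2, 256, 1024), (2, 131))
log.record((2, 256, 1024), (2, 97))
log.record((2, 256, 1024), (2, 131))
print(log.num_distinct())  # 2
```

Data-dependent values in something like `sequence_info_dict` can also change the traced graph without changing the shapes, so this count is a lower bound on the number of compilations.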

Suggestions

  1. Have you used persistent caching? If not, please take a look at https://github.com/pytorch/xla/blob/master/API_GUIDE.md#compilation-caching. If the dynamism in your code is relatively limited, enabling persistent caching would fix the issue (each possible combination is compiled once and remembered across runs).
  2. Maybe try eager mode. This is an experimental feature, so you will need a nightly build. Take a look at https://github.com/pytorch/xla/blob/master/examples/eager/train_decoder_only_eager.py#L10. You can enable eager mode in the dynamic region and disable it right after. Alternatively, you can follow https://github.com/pytorch/xla/blob/master/examples/eager/train_decoder_only_eager_with_compile.py, which turns eager mode on by default and lets you manually pick the regions to compile. Eager + compile is the UX I want to make the default next year, so I'd appreciate any feedback.

For the nightly, you can try

pip3 install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-nightly+20240701-cp310-cp310-linux_x86_64.whl
pip3 install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-nightly+20240701-cp310-cp310-linux_x86_64.whl
pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html

since last night's nightly seems to be broken.

Eager mode essentially compiles op by op: each op is compiled once per input shape, so the overall compile time is usually lower. Let me know how the two suggestions above work for you.
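A toy counting model (an illustration under simplifying assumptions, not how the XLA compiler actually accounts for work) shows why per-op compilation can be cheaper: a traced graph is recompiled for every distinct combination of shapes across all of its ops, while eager mode compiles each op once per distinct input shape. The op names and shapes below are hypothetical.

```python
from itertools import product


def graph_mode_compiles(steps):
    """One compilation per distinct whole-graph signature (all ops' shapes)."""
    return len({tuple(step) for step in steps})


def eager_mode_compiles(steps):
    """One compilation per distinct (op, input shape) pair."""
    return len({pair for step in steps for pair in step})


# Assume 3 ops in the dynamic region, each independently seeing one of 4 shapes.
ops = ["gather", "concat", "scatter"]
shapes = [16, 32, 48, 64]
steps = [list(zip(ops, combo)) for combo in product(shapes, repeat=len(ops))]

print(graph_mode_compiles(steps))  # 64 (4**3 whole graphs)
print(eager_mode_compiles(steps))  # 12 (3 ops x 4 shapes)
```

If the ops' shapes always change together, the graph count drops to the number of shapes while the per-op count stays at ops × shapes, so the advantage depends on how the dynamism is distributed; each single-op compile is also much smaller than a whole-graph compile.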

Jiayi-Pan commented 2 days ago

Thank you for the instructions! Re Q1: that's correct! We deliberately pad both raw_image_tensor and input_embeddings to make their shapes static. Only fuse_embedding is recompiled, while llm and image_encoder, where most of the compute happens, are static. Re Q2: unfortunately it's very dynamic; there are at least on the order of thousands of combinations.
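With thousands of combinations, persistent caching alone would still mean thousands of compilations. A standard way to bound the count (a general technique, not one proposed in this thread) is to pad each dynamic length up to the nearest of a few fixed bucket sizes, in the same spirit as the padding already applied to raw_image_tensor and input_embeddings. A minimal sketch, assuming the dynamism is mostly in one sequence length; the bucket sizes are arbitrary assumptions.

```python
import bisect

# Assumed bucket sizes; in practice pick them from the observed length distribution.
BUCKETS = [128, 256, 512, 1024, 2048]


def bucket_length(n, buckets=BUCKETS):
    """Return the smallest bucket >= n (raises if n exceeds the largest)."""
    i = bisect.bisect_left(buckets, n)
    if i == len(buckets):
        raise ValueError(f"length {n} exceeds largest bucket {buckets[-1]}")
    return buckets[i]


def pad_to_bucket(seq, pad_value=0, buckets=BUCKETS):
    """Pad a list to its bucket length and return (padded, mask)."""
    target = bucket_length(len(seq), buckets)
    pad = target - len(seq)
    mask = [1] * len(seq) + [0] * pad
    return list(seq) + [pad_value] * pad, mask


lengths = [97, 131, 300, 301, 1500, 128]
print(sorted({bucket_length(n) for n in lengths}))  # [128, 256, 512, 2048]
```

With tensors, the bucketed target length would be passed to a padding op (for example `torch.nn.functional.pad`) together with an attention mask, so the number of graphs to compile, or to keep in the persistent cache, is bounded by the number of buckets rather than the number of raw lengths.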

Eager mode looks very promising; however, I'm unable to install the nightly:

tpu-vm:~$ pip3 install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-nightly-cp310-cp310-linux_x86_64.whl
Defaulting to user installation because normal site-packages is not writeable
ERROR: Invalid requirement: 'torch-xla==nightly': Expected end or semicolon (after name and no valid version specifier)
    torch-xla==nightly
JackCaoG commented 2 days ago

Hmm, that's weird. Can you access https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-nightly+20240701-cp310-cp310-linux_x86_64.whl directly? If I click on this link, it just downloads the whl file for me. Also, is your Python version 3.10?

Jiayi-Pan commented 2 days ago

I can access it, and it's 3.10. But the issue is still there

jiayipan@t1v-n-f6802337-w-0:~$ python --version
Python 3.10.12
jiayipan@t1v-n-f6802337-w-0:~$ pip3 install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-nightly+20240701-cp310-cp310-linux_x86_64.whl
Defaulting to user installation because normal site-packages is not writeable
ERROR: Invalid requirement: 'torch-xla==nightly+20240701': Expected end or semicolon (after name and no valid version specifier)
    torch-xla==nightly+20240701
             ^
jiayipan@t1v-n-f6802337-w-0:~$ wget https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-nightly+20240701-cp310-cp310-linux_x86_64.whl
--2024-07-03 17:29:57--  https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-nightly+20240701-cp310-cp310-linux_x86_64.whl
Resolving storage.googleapis.com (storage.googleapis.com)... 108.177.12.207, 173.194.217.207, 74.125.26.207, ...
Connecting to storage.googleapis.com (storage.googleapis.com)|108.177.12.207|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 83362771 (80M) [application/octet-stream]
Saving to: ‘torch_xla-nightly+20240701-cp310-cp310-linux_x86_64.whl’

torch_xla-nightly+2 100%[===================>]  79.50M  88.9MB/s    in 0.9s

2024-07-03 17:29:58 (88.9 MB/s) - ‘torch_xla-nightly+20240701-cp310-cp310-linux_x86_64.whl’ saved [83362771/83362771]

jiayipan@t1v-n-f6802337-w-0:~$ pip install
.bash_history
.bash_logout
.bashrc
.cache/
.config/
.local/
.profile
.ssh/
.viminfo
buckets/
prismatic-video-lms/
torch_xla-nightly+20240701-cp310-cp310-linux_x86_64.whl
torch_xla-nightly-cp310-cp310-linux_x86_64.whl
jiayipan@t1v-n-f6802337-w-0:~$ pip install torch_xla-nightly+20240701-cp310-cp310-linux_x86_64.whl
Defaulting to user installation because normal site-packages is not writeable
ERROR: Invalid requirement: 'torch-xla==nightly+20240701': Expected end or semicolon (after name and no valid version specifier)
    torch-xla==nightly+20240701
             ^
jiayipan@t1v-n-f6802337-w-0:~$
JackCaoG commented 2 days ago

Hmm, I can't repro this, which is a bit weird. Maybe try manually renaming the whl? Something like:

mv torch_xla-nightly+20240701-cp310-cp310-linux_x86_64.whl torch_xla-nightly-cp310-cp310-linux_x86_64.whl
Jiayi-Pan commented 1 day ago

Some updates.

On reproducing the installation issue: it turns out that the installation error only happens after running

python3 -m pip install --upgrade pip 

On a clean TPU v3 VM with Ubuntu 22.04, you should be able to reproduce the error with:

python3 -m pip install --upgrade pip 
pip3 install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-nightly+20240701-cp310-cp310-linux_x86_64.whl
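The error is consistent with pip's version handling: the version field of the wheel filename, `nightly+20240701`, is not a valid PEP 440 version, and pip 24.1 removed support for such non-PEP 440 ("legacy") versions, which would explain why the error only appears after upgrading pip. The stdlib-only sketch below checks a wheel filename against the PEP 440 grammar (adapted from the `packaging` library's `VERSION_PATTERN`); the renamed `torch_xla-nightly-...` file fails the same check, matching the earlier `torch-xla==nightly` error. Pinning pip below 24.1 (e.g. `python3 -m pip install 'pip<24.1'`) is one possible stopgap, assuming nothing else needs a newer pip.

```python
import re

# PEP 440 version grammar, adapted from the `packaging` library's VERSION_PATTERN.
PEP440_VERSION = re.compile(
    r"""
    ^\s*v?
    (?:[0-9]+!)?                                               # epoch
    [0-9]+(?:\.[0-9]+)*                                        # release, must start with a digit
    (?:[-_.]?(?:alpha|a|beta|b|preview|pre|c|rc)[-_.]?[0-9]*)? # pre-release
    (?:-[0-9]+|[-_.]?(?:post|rev|r)[-_.]?[0-9]*)?              # post-release
    (?:[-_.]?dev[-_.]?[0-9]*)?                                 # dev release
    (?:\+[a-z0-9]+(?:[-_.][a-z0-9]+)*)?                        # local version
    \s*$
    """,
    re.VERBOSE | re.IGNORECASE,
)


def wheel_version(filename):
    """Version field of a wheel filename: name-version(-build)?-py-abi-plat.whl."""
    return filename[: -len(".whl")].split("-")[1]


def is_pep440(version):
    return PEP440_VERSION.match(version) is not None


for whl in [
    "torch_xla-nightly+20240701-cp310-cp310-linux_x86_64.whl",
    "torch_xla-nightly-cp310-cp310-linux_x86_64.whl",
]:
    v = wheel_version(whl)
    print(v, is_pep440(v))
# nightly+20240701 False
# nightly False
```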
Jiayi-Pan commented 1 day ago

On eager mode: I tried eager mode! The code structure is basically as follows:

import torch_xla
import torch_xla.core.xla_model as xm

# loader, image_encoder, get_text_embedding, fuse_embedding and llm are our own code
for raw_image_tensor, text_token_idxs, sequence_info_dict in loader:  # loading data
  # these tensors have static shapes; XLA works great on them
  image_embeddings = image_encoder(raw_image_tensor)
  text_embeddings = get_text_embedding(text_token_idxs)

  xm.mark_step()
  # this part is very light in compute, but dynamic, so run it op by op in eager mode
  torch_xla.experimental.eager_mode(True)
  input_embeddings = fuse_embedding(image_embeddings, text_embeddings, sequence_info_dict)
  torch_xla.experimental.eager_mode(False)
  xm.mark_step()

  # these tensors have static shapes; XLA works great on them
  output_logits = llm(input_embeddings)
  # loss compute / backward / optimizer step omitted

Unfortunately, the code hangs and never reaches output_logits = llm(input_embeddings). (It still works fine on the nightly when I disable eager mode.) Do you have any suggestions for debugging? There are a few mark_step calls around and within fuse_embedding; I'm not sure whether they cause any trouble.
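Independently of the hang, the enable/disable pair around fuse_embedding can be wrapped in a context manager so that eager mode is always switched back off, even if the region raises. A minimal sketch: `eager_region` is a hypothetical helper that takes the mode setter as an argument (in the loop above that is `torch_xla.experimental.eager_mode`), so the pattern can be exercised without torch_xla installed.

```python
from contextlib import contextmanager


@contextmanager
def eager_region(set_eager_mode):
    """Enable eager mode for the enclosed block, then always disable it.

    `set_eager_mode` is assumed to take a single bool, the way
    torch_xla.experimental.eager_mode is called in the loop above.
    """
    set_eager_mode(True)
    try:
        yield
    finally:
        set_eager_mode(False)


# Stand-in setter that records calls, so the sketch runs without torch_xla.
calls = []
with eager_region(calls.append):
    pass
print(calls)  # [True, False]

# In the training loop this would read, e.g.:
#   with eager_region(torch_xla.experimental.eager_mode):
#       input_embeddings = fuse_embedding(...)
```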