Open Jiayi-Pan opened 2 days ago
Great question. I have a couple of questions and a couple of suggestions.
Questions
1. `fuse_embedding` is dynamic, but the shape of `input_embeddings` is static? That would explain why the `llm` HLO is static.
2. How dynamic is `fuse_embedding`? For example, are there a total of 100 different shape combinations possible, or can there be literally thousands of different shape combinations?
Suggestions
1. Do you use persistent caching? If not, please take a look at https://github.com/pytorch/xla/blob/master/API_GUIDE.md#compilation-caching. If there is relatively little dynamism in your code, enabling persistent caching would fix the issue (you can compile and remember all possible combinations).
2. For the nightly, you can try:
```
pip3 install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-nightly+20240701-cp310-cp310-linux_x86_64.whl
pip3 install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-nightly+20240701-cp310-cp310-linux_x86_64.whl
pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html
```
since last night's nightly seems to be broken.
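The persistent caching suggested above can be enabled with `torch_xla.runtime.initialize_cache`; here is a minimal sketch, assuming torch_xla >= 2.3 and an arbitrary cache path, guarded so it is a no-op on machines without torch_xla installed:

```python
# Sketch: enable PyTorch/XLA's persistent compilation cache so graphs
# compiled in one run can be reloaded from disk in later runs.
# "/tmp/xla_compile_cache" is an arbitrary path chosen for illustration.
try:
    import torch_xla.runtime as xr

    xr.initialize_cache("/tmp/xla_compile_cache", readonly=False)
    cache_enabled = True
except ImportError:
    # torch_xla is not installed on this machine (e.g. a non-TPU host).
    cache_enabled = False
```

Note that the cache should be initialized before any computation is dispatched to the device.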
Eager mode pretty much just compiles op by op. It compiles each op once per input shape, so the overall compile time is usually lower. Let me know how the above two suggestions work for you.
Thank you for the instructions!
Re Q1: that's correct! We deliberately pad both `raw_image_tensor` and `input_embeddings` to make their shapes static. Only `fuse_embedding` is recompiled, while `llm` and `image_encoder`, where most of the compute happens, stay static.
Re Q2: unfortunately it's very dynamic; it should be at least on the order of thousands of shape combinations.
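(As an aside, one standard way to collapse thousands of length combinations into a handful is to round the dynamic dimension up to a bucket boundary before padding. A minimal sketch; the bucket size and maximum length here are illustrative assumptions, not values from this thread:)

```python
import math

def pad_to_bucket(length: int, bucket_size: int = 128, max_len: int = 2048) -> int:
    """Round a dynamic sequence length up to the nearest bucket boundary,
    so XLA only ever sees max_len / bucket_size distinct padded shapes."""
    return min(max_len, bucket_size * math.ceil(length / bucket_size))

# Lengths 1..2048 now map to only 16 distinct padded shapes.
assert pad_to_bucket(1) == 128
assert pad_to_bucket(129) == 256
assert len({pad_to_bucket(n) for n in range(1, 2049)}) == 16
```

Each distinct padded shape is compiled once, after which every batch hits the compilation cache.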
Eager mode looks very promising; however, I'm unable to install the nightly:
```
tpu-vm:~$ pip3 install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-nightly-cp310-cp310-linux_x86_64.whl
Defaulting to user installation because normal site-packages is not writeable
ERROR: Invalid requirement: 'torch-xla==nightly': Expected end or semicolon (after name and no valid version specifier)
torch-xla==nightly
```
Hmm, that's weird. Can you access https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-nightly+20240701-cp310-cp310-linux_x86_64.whl directly? If I click on this link it just downloads the whl file for me. Also, is your Python version 3.10?
I can access it, and it's Python 3.10, but the issue is still there:
```
jiayipan@t1v-n-f6802337-w-0:~$ python --version
Python 3.10.12
jiayipan@t1v-n-f6802337-w-0:~$ pip3 install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-nightly+20240701-cp310-cp310-linux_x86_64.whl
Defaulting to user installation because normal site-packages is not writeable
ERROR: Invalid requirement: 'torch-xla==nightly+20240701': Expected end or semicolon (after name and no valid version specifier)
torch-xla==nightly+20240701
         ^
jiayipan@t1v-n-f6802337-w-0:~$ wget https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-nightly+20240701-cp310-cp310-linux_x86_64.whl
--2024-07-03 17:29:57--  https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-nightly+20240701-cp310-cp310-linux_x86_64.whl
Resolving storage.googleapis.com (storage.googleapis.com)... 108.177.12.207, 173.194.217.207, 74.125.26.207, ...
Connecting to storage.googleapis.com (storage.googleapis.com)|108.177.12.207|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 83362771 (80M) [application/octet-stream]
Saving to: ‘torch_xla-nightly+20240701-cp310-cp310-linux_x86_64.whl’
torch_xla-nightly+2 100%[===================>]  79.50M  88.9MB/s    in 0.9s
2024-07-03 17:29:58 (88.9 MB/s) - ‘torch_xla-nightly+20240701-cp310-cp310-linux_x86_64.whl’ saved [83362771/83362771]
jiayipan@t1v-n-f6802337-w-0:~$ pip install
.bash_history  .bash_logout  .bashrc  .cache/  .config/  .local/  .profile  .ssh/  .viminfo
buckets/  prismatic-video-lms/
torch_xla-nightly+20240701-cp310-cp310-linux_x86_64.whl  torch_xla-nightly-cp310-cp310-linux_x86_64.whl
jiayipan@t1v-n-f6802337-w-0:~$ pip install torch_xla-nightly+20240701-cp310-cp310-linux_x86_64.whl
Defaulting to user installation because normal site-packages is not writeable
ERROR: Invalid requirement: 'torch-xla==nightly+20240701': Expected end or semicolon (after name and no valid version specifier)
torch-xla==nightly+20240701
         ^
jiayipan@t1v-n-f6802337-w-0:~$
```
Hmm, I can't repro this, which is a bit weird. Maybe manually rename the whl? Something like:
```
mv torch_xla-nightly+20240701-cp310-cp310-linux_x86_64.whl torch_xla-nightly-cp310-cp310-linux_x86_64.whl
```
Some updates.
On reproducing the installation issue
It turns out that the installation error only happens after `python3 -m pip install --upgrade pip`. Given a clean `tpu-v3` VM with `ubuntu-22.04`, you should be able to reproduce the error with:
```
python3 -m pip install --upgrade pip
pip3 install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-nightly+20240701-cp310-cp310-linux_x86_64.whl
```
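(This is consistent with newer pip releases strictly enforcing PEP 440 version syntax instead of falling back to legacy parsing: the wheel's version, `nightly+20240701`, has a non-numeric release segment, so upgraded pip rejects it while the VM's stock pip accepted it. A simplified, stdlib-only illustration of the release-segment rule; the helper name and regex are illustrative, not pip's actual implementation:)

```python
import re

# Simplified PEP 440 check: a version's release segment must be a dotted
# run of digits (e.g. "2.4.0"), optionally followed by pre/dev/local parts.
# "nightly+20240701" fails because its release segment is the word "nightly".
RELEASE_RE = re.compile(r"^\d+(\.\d+)*")

def has_valid_release_segment(version: str) -> bool:
    return RELEASE_RE.match(version) is not None

assert not has_valid_release_segment("nightly+20240701")
assert has_valid_release_segment("2.4.0.dev20240701")
```

Renaming the wheel, as suggested above, works because pip parses the version out of the filename.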
On Eager Mode
I tried eager mode! The code structure is basically as shown here:
```python
for ...:  # loading data
    # these tensors have static shapes; XLA works great on them
    image_embeddings = image_encoder(raw_image_tensor)
    text_embeddings = get_text_embedding(text_token_idxs)
    xm.mark_step()

    # this part is very light in compute, but dynamic; we currently
    # just recompile this graph every single time :(
    torch_xla.experimental.eager_mode(True)
    input_embeddings = fuse_embedding(raw_image_tensor, text_token_idxs, sequence_info_dict)
    torch_xla.experimental.eager_mode(False)
    xm.mark_step()

    # these tensors have static shapes; XLA works great on them
    output_logits = llm(input_embeddings)
    # loss compute / backward / optimizer step omitted
```
Unfortunately, the code hangs and never reaches `output_logits = llm(input_embeddings)`. (It still works fine on nightly when I disable eager mode.)
Do you have any suggestions for debugging? There are a few `mark_step` calls around/within `fuse_embedding`; I'm not sure if they cause any trouble.
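(For hangs like this, one starting point from PyTorch/XLA's troubleshooting guide is the `PT_XLA_DEBUG` environment variable plus the metrics report, which show what is triggering compilations and whether the process is stuck compiling rather than deadlocked. A guarded sketch, a no-op on machines without torch_xla:)

```python
import os

# Per the PyTorch/XLA troubleshooting guide, PT_XLA_DEBUG=1 prints an
# analysis of what triggered each compilation/execution. It must be set
# before torch_xla is imported.
os.environ["PT_XLA_DEBUG"] = "1"

try:
    import torch_xla.debug.metrics as met

    # Counters such as CompileTime and ExecuteTime reveal whether the
    # program is spending its time compiling.
    report = met.metrics_report()
except ImportError:
    report = ""  # torch_xla is not installed on this machine
```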
❓ Questions and Help
We are using PyTorch XLA with TPUs to train a multi-modal language model.
We can keep most of the code, such as image encoding and the forward pass of the LLM backbone, static in shape, which XLA handles well. However, making the part that fuses image and text embeddings into the input embeddings static is extremely challenging.
Currently, we use `mark_step` to isolate that section from the rest of the code, allowing it to recompile each time. Although this part is computationally very light, the recompilation is extremely slow and often consumes the majority of training time.
We find documentation on this issue very hard to find, and we are exploring better solutions, such as running that part on the CPU, running it in eager mode, or not saving that part of the graph to avoid OOM errors during long training runs. Do you have any suggestions/pointers on how to work around this inefficiency?
The following is pseudocode to illustrate our problem: