pytorch / xla

Enabling PyTorch on XLA Devices (e.g. Google TPU)
https://pytorch.org/xla

Exception on TPU when compiling gemma 2b #6711

Closed tengomucho closed 6 months ago

tengomucho commented 6 months ago

🐛 Bug

I tried to run inference on google/gemma-2b, and compiling the model raises an exception.

To Reproduce

I ran the script at https://gist.github.com/tengomucho/76fb3d630ac4a99c7f1f5e654700bb60 on a TPU V5e-litepod8.

Steps to reproduce the behavior:

DBG_COMPILE=1 python ./static_cache_test.py

Here's a stack trace I get:

Traceback (most recent call last):
  File "/home/amoran/optimum-tpu/alvaro/static_cache_test.py", line 114, in <module>
    next_token = decode_one_tokens(model, next_token.clone(), pos_ids, cache_position)
  File "/home/amoran/optimum-tpu/alvaro/static_cache_test.py", line 34, in decode_one_tokens
    logits = model(
  File "/home/amoran/Dev/venv/hf/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/amoran/Dev/venv/hf/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/amoran/Dev/venv/hf/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 489, in _fn
    return fn(*args, **kwargs)
  File "/home/amoran/Dev/venv/hf/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/amoran/Dev/venv/hf/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/amoran/Dev/venv/hf/lib/python3.10/site-packages/transformers/models/gemma/modeling_gemma.py", line 1019, in forward
    @add_start_docstrings_to_model_forward(GEMMA_INPUTS_DOCSTRING)
  File "/home/amoran/Dev/venv/hf/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 489, in _fn
    return fn(*args, **kwargs)
  File "/home/amoran/Dev/venv/hf/lib/python3.10/site-packages/torch/_dynamo/external_utils.py", line 17, in inner
    return fn(*args, **kwargs)
  File "/home/amoran/Dev/venv/hf/lib/python3.10/site-packages/torch/_dynamo/backends/torchxla.py", line 49, in fwd
    compiled_graph = bridge.extract_compiled_graph(model, args)
  File "/home/amoran/Dev/venv/hf/lib/python3.10/site-packages/torch_xla/core/dynamo_bridge.py", line 571, in extract_compiled_graph
    extract_internal(fused_module), node.args, None)
  File "/home/amoran/Dev/venv/hf/lib/python3.10/site-packages/torch_xla/core/dynamo_bridge.py", line 338, in extract_internal
    xm.mark_step()
  File "/home/amoran/Dev/venv/hf/lib/python3.10/site-packages/torch_xla/core/xla_model.py", line 891, in mark_step
    torch_xla._XLAC._xla_step_marker(
RuntimeError: ./torch_xla/csrc/runtime/pjrt_computation_client.h:146 : Check failed: HasValue() 
*** Begin stack trace ***
        tsl::CurrentStackTrace()
        torch_xla::runtime::PjRtComputationClient::PjRtData::GetHandle()
        torch::lazy::LazyGraphExecutor::RunPostOrder(std::vector<torch::lazy::Value, std::allocator<torch::lazy::Value> > const&, torch::lazy::LazyGraphExecutor::SyncTensorCollection*)
        torch_xla::XLAGraphExecutor::RunPostOrder(std::vector<torch::lazy::Value, std::allocator<torch::lazy::Value> > const&, torch::lazy::LazyGraphExecutor::SyncTensorCollection*)
        torch_xla::XLAGraphExecutor::SyncTensorsGraphInternal(std::vector<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> >, std::allocator<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> > > >*, absl::lts_20230802::Span<std::string const>, torch::lazy::LazyGraphExecutor::SyncTensorsConfig const&, bool)
        torch_xla::XLAGraphExecutor::SyncTensorsGraph(std::vector<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> >, std::allocator<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> > > >*, absl::lts_20230802::Span<std::string const>, bool, bool, bool)
        torch_xla::XLAGraphExecutor::SyncLiveTensorsGraph(torch::lazy::BackendDevice const*, c10::ArrayRef<std::string>, bool)

        _PyObject_MakeTpCall
        _PyEval_EvalFrameDefault
        _PyFunction_Vectorcall
        _PyEval_EvalFrameDefault
        _PyFunction_Vectorcall
        _PyEval_EvalFrameDefault
        _PyFunction_Vectorcall
        _PyEval_EvalFrameDefault
        _PyFunction_Vectorcall
        _PyEval_EvalFrameDefault
        _PyFunction_Vectorcall
        _PyEval_EvalFrameDefault

        _PyFunction_Vectorcall
        _PyEval_EvalFrameDefault

        PyObject_Call
        _PyEval_EvalFrameDefault

        PyObject_Call
        _PyEval_EvalFrameDefault

        PyObject_Call
        _PyEval_EvalFrameDefault
        _PyFunction_Vectorcall
        PyObject_Call
        _PyEval_EvalFrameDefault

        PyObject_Call
        _PyEval_EvalFrameDefault
        _PyFunction_Vectorcall
        _PyObject_FastCallDictTstate
        _PyObject_Call_Prepend

        _PyObject_MakeTpCall
        _PyEval_EvalFrameDefault
        _PyFunction_Vectorcall
        _PyEval_EvalFrameDefault

        PyEval_EvalCode

        _PyRun_SimpleFileObject
        _PyRun_AnyFileObject
        Py_RunMain
        Py_BytesMain

        __libc_start_main
        _start
*** End stack trace ***
buffer with shape bf16[2,1,1024,256] on device TPU:0 is null
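For context, a bf16[2,1,1024,256] buffer is consistent with one per-layer key or value tensor of a static KV cache for google/gemma-2b, which uses a single key/value head with head_dim 256. Reading the first and third dimensions as batch size 2 and a 1024-token maximum cache length is an assumption about the linked script, not something stated in this issue; the sketch below only reproduces that shape and its size on CPU under that assumed (batch, kv_heads, max_cache_len, head_dim) layout.

```python
import torch

# Attention dimensions from the public google/gemma-2b config.
NUM_KEY_VALUE_HEADS = 1
HEAD_DIM = 256


def static_kv_buffer(batch_size: int, max_cache_len: int, dtype=torch.bfloat16) -> torch.Tensor:
    # Assumed layout of one per-layer key (or value) buffer in a static cache:
    # (batch, num_key_value_heads, max_cache_len, head_dim)
    return torch.zeros(batch_size, NUM_KEY_VALUE_HEADS, max_cache_len, HEAD_DIM, dtype=dtype)


buf = static_kv_buffer(batch_size=2, max_cache_len=1024)
print(buf.dtype, tuple(buf.shape))  # torch.bfloat16 (2, 1, 1024, 256)
print(buf.numel() * buf.element_size())  # 1048576 bytes, i.e. 1 MiB per buffer
```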

Expected behavior

The script runs until the end, printing timings and results.

Environment

JackCaoG commented 6 months ago

Hmm, this seems like a real error with dynamo. @alanwaketan, do you know who benchmarks Gemma inference with dynamo?

alanwaketan commented 6 months ago

@JackCaoG No, I don't think anyone is working on it at the moment.

JackCaoG commented 6 months ago

What's the error you get when running with the nightly build?

tengomucho commented 6 months ago

I just re-tried with the nightly docker image (sha256:9c517c2514540d373cbb6d06333df144df1a3099626704558458da4cdf49adf6). With or without compilation, I get the same error:

Traceback (most recent call last):
  File "/workspace/alvaro/static_cache_test.py", line 117, in <module>
    next_token = decode_one_tokens(model, next_token.clone(), pos_ids, cache_position)
  File "/workspace/alvaro/static_cache_test.py", line 32, in decode_one_tokens
    logits = model(
  File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1536, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 451, in _fn
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1536, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/transformers/models/gemma/modeling_gemma.py", line 1025, in forward
    @add_start_docstrings_to_model_forward(GEMMA_INPUTS_DOCSTRING)
  File "/usr/local/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 451, in _fn
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/torch/_dynamo/external_utils.py", line 36, in inner
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/torch/_dynamo/backends/torchxla.py", line 51, in fwd
    compiled_graph = bridge.extract_compiled_graph(model, args)
  File "/usr/local/lib/python3.10/site-packages/torch_xla/core/dynamo_bridge.py", line 572, in extract_compiled_graph
    collector.run(*xla_args)
  File "/usr/local/lib/python3.10/site-packages/torch/fx/interpreter.py", line 145, in run
    self.env[node] = self.run_node(node)
  File "/usr/local/lib/python3.10/site-packages/torch_xla/core/dynamo_bridge.py", line 460, in run_node
    result = super().run_node(n)
  File "/usr/local/lib/python3.10/site-packages/torch/fx/interpreter.py", line 202, in run_node
    return getattr(self, n.op)(n.target, args, kwargs)
  File "/usr/local/lib/python3.10/site-packages/torch/fx/interpreter.py", line 319, in call_module
    return submod(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1536, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/sparse.py", line 163, in forward
    return F.embedding(
  File "/usr/local/lib/python3.10/site-packages/torch/nn/functional.py", line 2264, in embedding
    return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
RuntimeError: torch_xla/csrc/tensor_ops.cpp:248 : Check failed: indices->dtype() == at::ScalarType::Long (Int vs. Long)
*** Begin stack trace ***
        tsl::CurrentStackTrace()
        torch_xla::tensor_ops::Embedding(c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> > const&, c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> > const&)
        torch_xla::tensor_methods::embedding(c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> > const&, c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> > const&)
        torch_xla::XLANativeFunctions::embedding_symint(at::Tensor const&, at::Tensor const&, c10::SymInt, bool, bool)

        at::_ops::embedding::redispatch(c10::DispatchKeySet, at::Tensor const&, at::Tensor const&, c10::SymInt, bool, bool)

        at::_ops::embedding::call(at::Tensor const&, at::Tensor const&, c10::SymInt, bool, bool)

        _PyObject_MakeTpCall
        _PyEval_EvalFrameDefault

        _PyEval_EvalFrameDefault

        _PyEval_EvalFrameDefault

        _PyEval_EvalFrameDefault

        _PyObject_FastCallDictTstate
        _PyObject_Call_Prepend

        _PyObject_Call
        _PyEval_EvalFrameDefault

        _PyEval_EvalFrameDefault

        _PyEval_EvalFrameDefault

        _PyEval_EvalFrameDefault

        _PyEval_EvalFrameDefault

        _PyEval_EvalFrameDefault

        _PyEval_EvalFrameDefault

        _PyEval_EvalFrameDefault

        _PyEval_EvalFrameDefault

        PyVectorcall_Call
        _PyEval_EvalFrameDefault

        PyVectorcall_Call
        _PyEval_EvalFrameDefault

        PyVectorcall_Call
        _PyEval_EvalFrameDefault

        PyVectorcall_Call
        _PyEval_EvalFrameDefault

        PyVectorcall_Call
        _PyEval_EvalFrameDefault

        _PyObject_FastCallDictTstate
        _PyObject_Call_Prepend

        _PyObject_MakeTpCall
        _PyEval_EvalFrameDefault

        _PyEval_EvalFrameDefault

        PyEval_EvalCode

        _PyRun_SimpleFileObject
        _PyRun_AnyFileObject
        Py_RunMain
        Py_BytesMain
        __libc_start_main
        _start
*** End stack trace ***

The only difference is the last part of the error message. With compilation on, the error ends with:

While executing %hidden_states : [num_users=1] = call_module[target=L__self___model_embed_tokens](args = (%l_input_ids_,), kwargs = {})
Original traceback:
  File "/usr/local/lib/python3.10/site-packages/transformers/models/gemma/modeling_gemma.py", line 1073, in forward
    outputs = self.model(
  File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1536, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/transformers/models/gemma/modeling_gemma.py", line 869, in forward
    inputs_embeds = self.embed_tokens(input_ids)
alanwaketan commented 6 months ago

@ManfeiBai Can you take a look?

JackCaoG commented 6 months ago

Well, "Check failed: indices->dtype() == at::ScalarType::Long (Int vs. Long)" seems easy to resolve: it looks like we are trying to index with an Int tensor, but XLA expects it to be Long.

JackCaoG commented 6 months ago

It is from https://github.com/pytorch/xla/blob/bf60b6233de8b84eed2b666c1951786ab26294fa/torch_xla/csrc/tensor_ops.cpp#L195-L199

One easy fix is to convert the index to Long if it is any integer type (I don't think there is an int128, so converting any integer index to Long should be safe).

JackCaoG commented 6 months ago

Wait, I think it is actually from https://github.com/pytorch/xla/blob/bf60b6233de8b84eed2b666c1951786ab26294fa/torch_xla/csrc/tensor_ops.cpp#L245-L248

The check there already accepts at::kInt; that was fixed by https://github.com/pytorch/xla/pull/6718, which merged today. If you wait until tomorrow, the nightly should include the fix.
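On torch_xla builds that predate that fix, the conversion suggested above can also be done on the user side before the embedding lookup. The sketch below is a hypothetical helper (embed_tokens is not part of transformers or torch_xla); it only shows the idea of widening integer indices to int64 (Long), which never loses information for int8/int16/int32/uint8 values.

```python
import torch
import torch.nn.functional as F

# Integer dtypes that can be widened to int64 without changing any value.
_NARROW_INT_DTYPES = (torch.uint8, torch.int8, torch.int16, torch.int32)


def embed_tokens(input_ids: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    """Hypothetical workaround: cast narrow integer indices to Long before F.embedding."""
    if input_ids.dtype in _NARROW_INT_DTYPES:
        input_ids = input_ids.to(torch.long)
    return F.embedding(input_ids, weight)


weight = torch.randn(10, 4)                          # vocab of 10, embedding dim 4
ids = torch.tensor([[1, 2, 3]], dtype=torch.int32)   # Int indices, as in the error above
out = embed_tokens(ids, weight)                      # lookup runs on Long indices
```

In practice the simplest equivalent is to call input_ids.long() before passing the ids to the model.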

ManfeiBai commented 6 months ago

Hi, thanks! I will test locally with torch_xla built with https://github.com/pytorch/xla/pull/6718 to confirm. Do we have a repo or commands to reproduce this first?

tengomucho commented 6 months ago

So, with the current nightly I no longer see errors when running without compilation. With compilation enabled, I get the same error as in the issue description.

alanwaketan commented 6 months ago

@ManfeiBai Can you try reproducing it?

ManfeiBai commented 6 months ago

sure, will do

ManfeiBai commented 6 months ago

Hi @tengomucho, is there a link or repo I could pull to reproduce this failure locally?

alanwaketan commented 6 months ago

@ManfeiBai I think @tengomucho linked the script in the description.

alanwaketan commented 6 months ago

@JackCaoG Based on the latest reply from @tengomucho, this seems like a dynamo issue. Can you take a look as well?

JackCaoG commented 6 months ago

I most likely won't have cycles until this Friday; I will try to take a look then.

ManfeiBai commented 6 months ago

Thanks! After syncing with @alanwaketan, I reproduced it locally: https://gist.github.com/ManfeiBai/9ed8b9790fe849d92df622653b398035

With DBG_COMPILE=1, I saw this error:

RuntimeError: ./torch_xla/csrc/runtime/pjrt_computation_client.h:153 : Check failed: HasValue() 

Without DBG_COMPILE=1, the program finished.

JackCaoG commented 6 months ago

Yeah, dynamo is only enabled when DBG_COMPILE=1, which matches @tengomucho's observation. One thing to try is the openxla backend instead of openxla_eval. The openxla backend runs AOTAutograd and lowers the graph to ATen ops; I'm not sure whether this will make any difference to the assertion error above.
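The switch being discussed can be sketched as below. maybe_compile is a hypothetical helper that mirrors the DBG_COMPILE behavior described in this thread; it is not taken from the linked gist. The only change being suggested is the backend string passed to torch.compile. Actually compiling with "openxla" or "openxla_eval" assumes torch_xla is installed and the model is on an XLA device.

```python
import os

import torch


def maybe_compile(model: torch.nn.Module, backend: str = "openxla", env=None):
    """Hypothetical helper: wrap the model with torch.compile only when DBG_COMPILE=1.

    Otherwise the model is returned unchanged and runs through PyTorch/XLA's
    lazy-tensor path without dynamo.
    """
    env = os.environ if env is None else env
    if env.get("DBG_COMPILE") != "1":
        return model
    # "openxla" runs AOTAutograd first, so XLA receives an ATen-level graph;
    # "openxla_eval" hands dynamo's FX graph to XLA without that step.
    return torch.compile(model, backend=backend)


# Usage (assumption: a TPU VM with torch_xla installed):
# import torch_xla.core.xla_model as xm
# model = model.to(xm.xla_device())
# model = maybe_compile(model, backend="openxla")  # instead of "openxla_eval"
```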

tengomucho commented 6 months ago

Hey @JackCaoG, you are right: setting the compilation backend to openxla made the error disappear!

JackCaoG commented 6 months ago

@tengomucho Yay! We should consider removing the openxla_eval backend, since openxla seems more mature. Let me check with the team about the performance difference between the two.

alanwaketan commented 6 months ago

@JackCaoG We use openxla_eval by default in most of our examples, lol. We can re-benchmark to see whether the performance gap is gone.

JackCaoG commented 6 months ago

@alanwaketan Sounds good. Based on the torchbench results we got recently, openxla has a higher pass rate and similar performance compared to openxla_eval.

JackCaoG commented 6 months ago

@tengomucho I am going to close this issue for now; feel free to open a new one if you run into other issues.