pytorch/xla

Enabling PyTorch on XLA Devices (e.g. Google TPU)
https://pytorch.org/xla

Core dump on TPU using transformers' generate #7122

Open tengomucho opened 1 month ago

tengomucho commented 1 month ago

🐛 Bug

Whenever I use the generate function on a TPU (I use a v5e litepod8), I get a crash with a C++ stack trace, no info on the Python side, and no way to catch or recover.

To Reproduce

It is fairly easy to reproduce; I used this script:

import os
os.environ["PJRT_DEVICE"] = "TPU"  # use the PJRT TPU runtime
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch_xla.core.xla_model as xm

model_id = "openai-community/gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Load the model and move it to the XLA (TPU) device
model = AutoModelForCausalLM.from_pretrained(model_id)
model = model.to(xm.xla_device())

text = "A good idea will"
inputs = tokenizer(text, return_tensors="pt").to("xla")

# This will make the system crash
outputs = model.generate(**inputs, max_new_tokens=20, do_sample=True)

The script crashes and gives this output:

F0000 00:00:1716818698.460641   95567 debug_macros.h:20] Non-OK-status: status.status() status: INVALID_ARGUMENT: Expected pred or integral type in argument to and/or operation; got F32.
*** Begin stack trace ***
        tsl::CurrentStackTrace()
        xla::Shape const* ConsumeValue<xla::Shape const*>(absl::lts_20230802::StatusOr<xla::Shape const*>&&)
        torch_xla::ShapeHelper::ShapeOfXlaOp(xla::XlaOp)
        torch_xla::XlaHelpers::TypeOfXlaOp(xla::XlaOp)
        torch_xla::XlaHelpers::PromotedBinaryOp(xla::XlaOp, xla::XlaOp, std::function<xla::XlaOp (xla::XlaOp, xla::XlaOp)> const&)

        torch_xla::InferOutputShape(absl::lts_20230802::Span<xla::Shape const>, std::function<xla::XlaOp (absl::lts_20230802::Span<xla::XlaOp const>)> const&)

        torch_xla::BitwiseOrTensorOutputShape(torch::lazy::Value const&, torch::lazy::Value const&)
        std::_Function_handler<xla::Shape (), torch_xla::BitwiseOrTensor::BitwiseOrTensor(torch::lazy::Value const&, torch::lazy::Value const&)::{lambda()#1}>::_M_invoke(std::_Any_data const&)
        torch_xla::XlaNode::GetOpShape(std::function<xla::Shape ()> const&) const
        torch_xla::XlaNode::XlaNode(torch::lazy::OpKind, c10::ArrayRef<torch::lazy::Value>, std::function<xla::Shape ()> const&, unsigned long, torch::lazy::hash_t)
        torch_xla::tensor_methods::bitwise_or(c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> > const&, c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> > const&)
        torch_xla::XLANativeFunctions::bitwise_or(at::Tensor const&, at::Tensor const&)

        c10::Dispatcher::callBoxed(c10::OperatorHandle const&, std::vector<c10::IValue, std::allocator<c10::IValue> >*) const

        at::_ops::bitwise_or_Tensor::call(at::Tensor const&, at::Tensor const&)
        at::native::__or__(at::Tensor const&, at::Tensor const&)

        at::_ops::__or___Tensor::call(at::Tensor const&, at::Tensor const&)

        _PyEval_EvalFrameDefault
        _PyObject_FastCallDictTstate
        _PyObject_Call_Prepend

        _PyObject_MakeTpCall
        _PyEval_EvalFrameDefault

        PyObject_Call
        _PyEval_EvalFrameDefault
        _PyFunction_Vectorcall
        PyObject_Call
        _PyEval_EvalFrameDefault

        PyObject_Call
        _PyEval_EvalFrameDefault

        PyEval_EvalCode

        _PyRun_SimpleFileObject
        _PyRun_AnyFileObject
        Py_RunMain
        Py_BytesMain

        __libc_start_main
        _start
*** End stack trace ***

*** Check failure stack trace: ***
    @     0x7f2563b6a5d9  absl::lts_20230802::log_internal::LogMessageFatal::~LogMessageFatal()
    @     0x7f255bd1b834  ConsumeValue<>()
    @     0x7f255bd1b89e  torch_xla::ShapeHelper::ShapeOfXlaOp()
    @     0x7f255b9965c9  torch_xla::XlaHelpers::TypeOfXlaOp()
    @     0x7f255b99ef67  torch_xla::XlaHelpers::PromotedBinaryOp()
    @     0x7f255bcd31c1  std::_Function_handler<>::_M_invoke()
    @     0x7f255bc8eba9  torch_xla::InferOutputShape()
    @     0x7f255bcd9373  torch_xla::(anonymous namespace)::InferBinaryOpShape()
    @     0x7f255bcd956e  torch_xla::BitwiseOrTensorOutputShape()
    @     0x7f255b9e8119  std::_Function_handler<>::_M_invoke()
    @     0x7f255bd0f2a6  torch_xla::XlaNode::GetOpShape()
    @     0x7f255bd0fb99  torch_xla::XlaNode::XlaNode()
    @     0x7f255b9fbeaf  torch_xla::tensor_methods::bitwise_or()
    @     0x7f255b932101  torch_xla::XLANativeFunctions::bitwise_or()
    @     0x7f255bb9a419  c10::impl::make_boxed_from_unboxed_functor<>::call()
    @     0x7f2613e59028  c10::Dispatcher::callBoxed()
    @     0x55b4b5977ad0  (unknown)
https://symbolize.stripped_domain/r/?trace=7f26b78969fc,7f26b784251f&map= 
*** SIGABRT received by PID 95567 (TID 95567) on cpu 56 from PID 95567; stack trace: ***
PC: @     0x7f26b78969fc  (unknown)  pthread_kill
    @     0x7f24c4a5e441        944  (unknown)
    @     0x7f26b7842520  (unknown)  (unknown)
https://symbolize.stripped_domain/r/?trace=7f26b78969fc,7f24c4a5e440,7f26b784251f&map= 
E0527 14:04:58.661783   95567 coredump_hook.cc:470] RAW: Remote crash data gathering hook invoked.
E0527 14:04:58.661795   95567 coredump_hook.cc:509] RAW: Skipping coredump since rlimit was 0 at process start.
E0527 14:04:58.661802   95567 client.cc:269] RAW: Coroner client retries enabled, will retry for up to 30 sec.
E0527 14:04:58.661807   95567 coredump_hook.cc:565] RAW: Sending fingerprint to remote end.
E0527 14:04:58.661820   95567 coredump_hook.cc:574] RAW: Cannot send fingerprint to Coroner: [NOT_FOUND] stat failed on crash reporting socket /var/google/services/logmanagerd/remote_coredump.socket (Is the listener running?): No such file or directory
E0527 14:04:58.661826   95567 coredump_hook.cc:626] RAW: Dumping core locally.
E0527 14:04:58.884137   95567 process_state.cc:799] RAW: Raising signal 6 with default behavior
Aborted (core dumped)
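
The trace bottoms out in torch_xla::XLANativeFunctions::bitwise_or, i.e. a Python-level | on XLA tensors where one operand is traced as F32. Below is a minimal sketch that exercises that op directly, without transformers; it is only my guess at the trigger based on the trace, not a verified reduction, and the torch.full dtype behavior it assumes is discussed in the comments further down.

import os
os.environ["PJRT_DEVICE"] = "TPU"
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()

# Assumption: torch.full with a Python bool fill value is what ends up
# traced as F32 on torch_xla 2.3 (see the discussion below).
flags = torch.full((4,), False, device=device)
mask = torch.zeros(4, dtype=torch.bool, device=device)

# The | dispatches to torch_xla::XLANativeFunctions::bitwise_or, the frame
# where the INVALID_ARGUMENT abort is raised in the trace above.
print(flags | mask)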

Expected behavior

I expect the code snippet to execute without any error.

Environment

JackCaoG commented 1 month ago

I think it is the same issue as https://github.com/pytorch/xla/issues/6991, and I fixed it in the nightly.

JackCaoG commented 1 month ago

It would be hard for us to update 2.3 at this point, but what we can do is force the typing in the Python layer.

is_done = torch.full((input_ids.shape[0],), False, device=input_ids.device)

For example, we can do a manual .to(torch.bool) on is_done.
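
For context, in recent transformers versions that line lives in StoppingCriteriaList.__call__, where is_done is then OR-ed with each criterion's output; that OR is the bitwise_or frame in the stack trace. A hedged sketch of the workaround in that spot follows; the surrounding code is my recollection of the transformers source, not verbatim.

import torch

# Rough shape of transformers/generation/stopping_criteria.py (recent versions)
class StoppingCriteriaList(list):
    def __call__(self, input_ids, scores, **kwargs):
        # Force a bool tensor: on torch_xla 2.3 a plain torch.full(..., False)
        # otherwise gets traced as F32, and the | below aborts with
        # "Expected pred or integral type in argument to and/or operation".
        is_done = torch.full(
            (input_ids.shape[0],), False, device=input_ids.device
        ).to(torch.bool)
        for criteria in self:
            is_done = is_done | criteria(input_ids, scores, **kwargs)
        return is_done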

tengomucho commented 1 month ago

The good news is that I tried the nightly version of torch_xla and it seems to work fine. The bad news is that the workaround does not work for 2.3: I can assign the result to a variable before returning and print it out, but as soon as it returns, it crashes. If you can think of any other workaround, that would be great.

JackCaoG commented 1 month ago

I verified it on my end with 2.3:

is_done = torch.full((input_ids.shape[0],), False, device=input_ids.device).to(torch.bool)

This fixed the runtime error.
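
That matches the INVALID_ARGUMENT message at the top of the trace: XLA's and/or operation only accepts pred or integral operands, so forcing is_done to torch.bool before the | keeps the traced operands on a pred type.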