tengomucho opened this issue 1 month ago (status: Open)
I think it is the same issue as https://github.com/pytorch/xla/issues/6991, which I fixed in the nightly. It would be hard for us to update 2.3 at this point, but what we can do is force the typing in the Python layer. For

is_done = torch.full((input_ids.shape[0],), False, device=input_ids.device)

for example, we can do a manual .to(torch.bool) on is_done.
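A minimal sketch of that Python-layer workaround on an XLA device (the device setup and the dummy input_ids below are assumptions for illustration, not the actual transformers code path):

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
# Stand-in for the real input_ids produced by the tokenizer; the shape is arbitrary.
input_ids = torch.zeros((4, 16), dtype=torch.long, device=device)

# torch.full with a Python bool and no explicit dtype leaves the element type
# to be inferred; forcing bool makes the typing explicit for the 2.3 runtime.
is_done = torch.full((input_ids.shape[0],), False, device=input_ids.device).to(torch.bool)
print(is_done.dtype)  # torch.bool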
The good news is that I tried the nightly version of torch_xla and it seems to work fine. The bad news is that the workaround does not work for 2.3: I can assign the result to a variable before returning and print it out, but as soon as it returns, it crashes. If you can think of any other workaround, that would be great.
I verified it on my end: with 2.3,

is_done = torch.full((input_ids.shape[0],), False, device=input_ids.device).to(torch.bool)

fixed the runtime error.
🐛 Bug
Whenever I use the generate function on a TPU (I use a v5e litepod8), I get a crash with a C++ stack trace but no info on the Python side and no way to catch or recover.

To Reproduce
It is fairly easy to reproduce; I used this script:
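(The script itself did not survive this copy; as a rough sketch, a reproduction would look roughly like the following. The model name, prompt, and generation parameters are my assumptions, not the original script.)

import torch_xla.core.xla_model as xm
from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholder model and prompt (assumptions); the report only says that
# calling generate() on the TPU device crashes with a C++ stack trace.
model_id = "gpt2"
device = xm.xla_device()

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id).to(device)

inputs = tokenizer("Hello, my name is", return_tensors="pt").to(device)
output = model.generate(**inputs, max_new_tokens=20)  # crash reported inside generate()
print(tokenizer.decode(output[0]))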
The script crashes and gives this output:
Expected behavior
I expect the code snippet to execute without any error.
Environment