pytorch / xla

Enabling PyTorch on XLA Devices (e.g. Google TPU)
https://pytorch.org/xla

PyTorch/XLA appears to be unable to access TPUs on Kaggle #5774

Open VatsaDev opened 11 months ago

VatsaDev commented 11 months ago

❓ Questions and Help

Hi, I was working on porting code to run on TPUs through the TRC (TPU Research Cloud), and was testing TPU VMs with Kaggle.

Following https://github.com/pytorch/xla/blob/master/docs/pjrt.md, on Kaggle I consistently receive this error:

Traceback (most recent call last):
  File "sample-XLA.py", line 40, in <module>
    checkpoint_dict = torch.load(checkpoint, map_location=xm.xla_device())
  File "/usr/local/lib/python3.8/site-packages/torch/serialization.py", line 1014, in load
    return _load(opened_zipfile,
  File "/usr/local/lib/python3.8/site-packages/torch/serialization.py", line 1422, in _load
    result = unpickler.load()
  File "/usr/local/lib/python3.8/site-packages/torch/serialization.py", line 1392, in persistent_load
    typed_storage = load_tensor(dtype, nbytes, key, _maybe_decode_ascii(location))
  File "/usr/local/lib/python3.8/site-packages/torch/serialization.py", line 1366, in load_tensor
    wrap_storage=restore_location(storage, location),
  File "/usr/local/lib/python3.8/site-packages/torch/serialization.py", line 1299, in restore_location
    return default_restore_location(storage, str(map_location))
  File "/usr/local/lib/python3.8/site-packages/torch/serialization.py", line 381, in default_restore_location
    result = fn(storage, location)
  File "/usr/local/lib/python3.8/site-packages/torch/serialization.py", line 304, in _hpu_deserialize
    assert hpu is not None, "HPU device module is not loaded"
AssertionError: HPU device module is not loaded

yet `xm.xla_device()` reports `xla:0`, and I do set `os.environ['PJRT_DEVICE'] = 'TPU'`.
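A common workaround for device-specific `map_location` failures like this (a sketch only, not confirmed as the fix in this thread) is to deserialize the checkpoint onto CPU first and move the tensors to the XLA device afterwards. In the sketch below, `device` is a CPU stand-in so it runs anywhere; on a real TPU VM you would use `torch_xla.core.xla_model as xm` and `device = xm.xla_device()`.

```python
import os
import tempfile

import torch

# Stand-in for xm.xla_device(); swap in the real XLA device on a TPU VM.
device = torch.device("cpu")

# Create a toy checkpoint so the sketch is self-contained.
ckpt_path = os.path.join(tempfile.mkdtemp(), "ckpt.pt")
torch.save({"model": {"w": torch.ones(2, 2)}}, ckpt_path)

# Deserialize onto CPU explicitly, sidestepping torch.load's
# string-keyed device-restore dispatch on the XLA device string.
checkpoint_dict = torch.load(ckpt_path, map_location="cpu")

# Then move each tensor to the target device yourself.
state = {k: v.to(device) for k, v in checkpoint_dict["model"].items()}
```

The same two-step pattern (CPU load, then `.to(device)`) is the usual way to keep `torch.load` from ever seeing a backend-specific location string.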

Kaggle notebook: https://www.kaggle.com/code/vatsadev/notebook5e5db4afa5

What's the issue?

JackCaoG commented 10 months ago

It seems like you are trying to load a checkpoint from somewhere, and it failed the check in

https://github.com/pytorch/pytorch/blob/16f82198ca081c9c4e4e4b7f45d759147f870318/torch/serialization.py#L294-L298

This is new to me as well, but does your checkpoint have anything to do with HPU?
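For context on the check being referenced: `torch.serialization` dispatches checkpoint restores through a registry of per-backend location handlers, each of which returns `None` when the location string is not for its backend. The sketch below is a hypothetical, simplified reconstruction of that pattern, not torch's actual code; the names `register_package` and `default_restore_location` mirror the ones in the traceback.

```python
# Simplified sketch (hypothetical) of a priority-ordered registry of
# location handlers, like the one torch.serialization dispatches through.
_package_registry = []

def register_package(priority, deserializer):
    # Lower priority values are tried first.
    _package_registry.append((priority, deserializer))
    _package_registry.sort(key=lambda entry: entry[0])

def default_restore_location(storage, location):
    # Try each registered handler; the first non-None result wins.
    for _, fn in _package_registry:
        result = fn(storage, location)
        if result is not None:
            return result
    raise RuntimeError(f"no deserializer for location {location!r}")

# Example handler: only claims locations tagged "cpu".
def _cpu_deserialize(obj, location):
    if location == "cpu":
        return obj
    return None  # not ours; let the next handler try

register_package(10, _cpu_deserialize)

print(default_restore_location("storage-bytes", "cpu"))
```

Under this scheme, a handler that raises (rather than returning `None`) on a location it does not own would surface an error like the `AssertionError` above even though the checkpoint has nothing to do with that backend.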

VatsaDev commented 10 months ago

It's a PyTorch ckpt.pt file, and I'm trying to load the checkpoint onto a TPU. The checkpoint is from Andrej Karpathy's llama2.c repo. I made a toy 5-million-parameter model; it works with his code on CPU and GPU, and with the sample-XLA.py file I'm using, CPU works but XLA/TPU fails.