Open radna0 opened 2 months ago
@ManfeiBai can you take a look?
thanks, sure, will take a look
Hi, I tried this on a v4-32 locally. Since v4-32 is a multi-host device, I ran the commands on all workers, like this: https://gist.github.com/ManfeiBai/3a2ac89435dbb7a9914e34d24b8449ba
The code above finished and printed: https://gist.github.com/ManfeiBai/cb23bb15850c8167320e910cf4b3f95c
Hi @radna0, would you mind trying commands like https://gist.github.com/ManfeiBai/3a2ac89435dbb7a9914e34d24b8449ba on your local v4-32 again? Or would you mind sharing your commands so that I can try to reproduce it locally too? Please let us know if there are any updates.
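(The exact commands are in the gist above; for a multi-host slice they are typically of the form `gcloud compute tpus tpu-vm ssh <tpu-name> --zone=<zone> --worker=all --command="..."`, i.e. the same script is launched on every worker of the v4-32 at once. The `<tpu-name>` and `<zone>` placeholders here are illustrative, not taken from this issue.)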
Here's what I got running the script, @ManfeiBai: tpu_v4_logs.txt
🐛 Bug
To Reproduce
Steps to reproduce the behavior:
import torch
import torch_xla
import torch_xla.core.xla_model as xm

# enumerate all XLA devices visible to this host
devices = xm.get_xla_supported_devices()
print(f"Devices: {devices}")

# total[1] accumulates per-device memory limits, total[0] the memory actually used
total = {0: 0, 1: 0}
for device in devices:
    mem = round(xm.get_memory_info(device)["bytes_limit"] / 1e9, 2)
    total[1] += mem
    print(f'Total TPU device: {device} memory: {mem} GB')
print(f"Total TPU memory: {total[0]} / {total[1]} GB")

for device in devices:
    mem = round(xm.get_memory_info(device)["bytes_limit"] / 1e9, 2)
    # allocate a tensor with a random leading dimension to consume device memory
    t = torch.randn(torch.randint(1, 8, (1,)).item(), 4, 144, 720, 1280).to(device)
    mem_used = round(xm.get_memory_info(device)["bytes_used"] / 1e9, 2)
    total[0] += mem_used
    print(f'Total TPU device: {device} memory: {mem_used} / {mem} GB')
    xm.mark_step()
print(f"Total TPU memory: {total[0]} / {total[1]} GB")
Traceback (most recent call last):
  File "/home/kojoe/mem_check.py", line 6, in <module>
    devices = xm.get_xla_supported_devices()
  File "/home/kojoe/.local/lib/python3.10/site-packages/torch_xla/core/xla_model.py", line 93, in get_xla_supported_devices
    devices = torch_xla._XLAC._xla_get_devices()
RuntimeError: Bad StatusOr access: INTERNAL: Failed to get global TPU topology.
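(For context on the error: v4-32 is a multi-host slice, so the script generally has to be launched on every worker at once, which is what the gist above does. Below is a minimal, hypothetical per-worker sketch of the same memory check using an xmp.spawn entry point; it illustrates that launch pattern and is not a confirmed fix for this issue.)

import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp

def _mp_fn(index):
    # Each spawned process gets its own local XLA device; the file itself is
    # run on every worker of the slice (e.g. gcloud ... --worker=all).
    device = xm.xla_device()
    info = xm.get_memory_info(device)
    print(f"ordinal {xm.get_ordinal()} device {device}: "
          f"{round(info['bytes_limit'] / 1e9, 2)} GB limit")

if __name__ == '__main__':
    xmp.spawn(_mp_fn, args=())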