carmocca opened this issue 1 year ago
Since Neuron is based on XLA, wouldn't the TPU accelerator work out of the box with Trn1? The XLA libraries are already included in their AMI.
Possibly, but we would need to add testing and some glue code to advertise proper support.
Has there been any movement or is there an estimated timeline?
AFAIK this already works out of the box. cc @awaelchli or @justusschock to confirm.
Should we close this issue?
I'd be curious if it's known to work with torch 2.0, which appears to still be in beta. My TPU devices show up with torch 1.13, but not with torch 2.0.
torch-neuronx only added beta support for PyTorch 2.0 in last week's release:
https://awsdocs-neuron.readthedocs-hosted.com/en/latest/release-notes/index.html#latest-neuron-release
https://awsdocs-neuron.readthedocs-hosted.com/en/latest/general/appnotes/torch-neuronx/introducing-pytorch-2-0.html#introduce-pytorch-2-0
Interestingly, it drops XRT support in favor of PjRT
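For context, a minimal sketch of what that switch means from user code: torch_xla 2.x selects its runtime through the `PJRT_DEVICE` environment variable rather than the old `XRT_*` variables. The value `"NEURON"` below is an assumption for the torch-neuronx plugin name; check the Neuron docs for the exact value.

```python
# torch_xla 2.x picks its runtime via the PJRT_DEVICE environment variable
# instead of the old XRT_* variables. "NEURON" is an assumed plugin name
# for torch-neuronx; verify against the Neuron documentation.
import os

os.environ.setdefault("PJRT_DEVICE", "NEURON")
print(os.environ["PJRT_DEVICE"])
```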
> My TPU devices
Did you mean Trainium devices?
> torch-neuronx only added beta support for PyTorch 2.0 in last week's release
Ahh, my bad. I just started tinkering with neuron and was hoping to avoid re-factoring a few things. I'll just refactor for now. :)
> Did you mean Trainium devices?
Yeah, Trainium.
Thanks a lot for all of your hard work. This is rad tech!
@carmocca Generally it works, yes, since it is all abstracted through XLA. I tested it back then with the Neuron version that supported torch 1.13. We found some quirks though, one of them described here: #18851
With the latest Neuron version:

```
aws-neuronx-runtime-discovery==2.9
libneuronxla==1.0.680
neuronx-cc==2.11.0.34+c5231f848
neuronx-distributed==0.5.0
neuronx-hwm==2.11.0.2+e34678757
torch-neuronx==2.0.0.2.0.1b0
torch-xla==2.0.0+torchneuron0
```
there's an issue with running Lightning. While PyTorch sees the XLA devices:
```python
# xla_test.py
import torch
import torch_xla
import torch_xla.core.xla_model as xm

t = torch.randn(2, 2, device=xm.xla_device())
print(t.device)
print(t)
```
outputs:
```
python xla_test.py
xla:0
2023-11-14 17:19:03.000563: 6081 INFO ||NEURON_CACHE||: Compile cache path: /var/tmp/neuron-compile-cache
2023-11-14 17:19:03.000565: 6081 INFO ||NEURON_CC_WRAPPER||: Using a cached neff at /var/tmp/neuron-compile-cache/neuronxcc-2.11.0.34+c5231f848/MODULE_9711136460714136094+d41d8cd9/model.neff. Exiting with a successfully compiled graph.
tensor([[ 0.6607, -1.5726],
        [-1.7505, -1.4779]], device='xla:0')
```
the following Lightning code fails while trying to query instance metadata:
```python
# xla-lightning.py
from lightning.fabric.accelerators import XLAAccelerator

devices = XLAAccelerator.auto_device_count()
print(devices)
```
It fails with:

```
Traceback (most recent call last):
  File "xla-lightning.py", line 2, in <module>
    devices = XLAAccelerator.auto_device_count()
  File "/opt/aws_neuron_venv_pytorch/lib/python3.8/site-packages/lightning/fabric/accelerators/xla.py", line 83, in auto_device_count
    return device_count_on_version.get(tpu.version(), 8)
  File "/opt/aws_neuron_venv_pytorch/lib/python3.8/site-packages/torch_xla/experimental/tpu.py", line 118, in version
    env = get_tpu_env()
  File "/opt/aws_neuron_venv_pytorch/lib/python3.8/site-packages/torch_xla/experimental/tpu.py", line 112, in get_tpu_env
    metadata = _get_metadata('tpu-env')
  File "/opt/aws_neuron_venv_pytorch/lib/python3.8/site-packages/torch_xla/experimental/tpu.py", line 67, in _get_metadata
    resp = requests.get(path, headers={'Metadata-Flavor': 'Google'})
  File "/opt/aws_neuron_venv_pytorch/lib/python3.8/site-packages/requests/api.py", line 73, in get
    return request("get", url, params=params, **kwargs)
  File "/opt/aws_neuron_venv_pytorch/lib/python3.8/site-packages/requests/api.py", line 59, in request
    return session.request(method=method, url=url, **kwargs)
  File "/opt/aws_neuron_venv_pytorch/lib/python3.8/site-packages/requests/sessions.py", line 589, in request
    resp = self.send(prep, **send_kwargs)
  File "/opt/aws_neuron_venv_pytorch/lib/python3.8/site-packages/requests/sessions.py", line 703, in send
    r = adapter.send(request, **kwargs)
  File "/opt/aws_neuron_venv_pytorch/lib/python3.8/site-packages/requests/adapters.py", line 519, in send
    raise ConnectionError(e, request=request)
requests.exceptions.ConnectionError: HTTPConnectionPool(host='metadata.google.internal', port=80): Max retries exceeded with url: /computeMetadata/v1/instance/attributes/tpu-env (Caused by NewConnectionError('<urllib3.connection.HTTPConnection object at 0x7f5733fcfb50>: Failed to establish a new connection: [Errno -2] Name or service not known'))
```
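The traceback shows that `auto_device_count()` reaches `torch_xla.experimental.tpu`, which probes the Google Cloud metadata server, a host that doesn't exist on EC2. One possible workaround (a sketch under my own assumptions, not an official fix) is to skip auto-detection and derive the device count from the Neuron runtime's environment instead:

```python
# Sketch: instead of XLAAccelerator.auto_device_count(), which ends up
# probing the GCE metadata server via torch_xla.experimental.tpu, derive
# the device count locally. NEURON_RT_NUM_CORES is a real Neuron runtime
# variable, but treating it as the device count is a heuristic of mine.
import os

def neuron_device_count(default: int = 2) -> int:
    try:
        return int(os.environ["NEURON_RT_NUM_CORES"])
    except (KeyError, ValueError):
        return default  # a trn1.2xlarge exposes 2 NeuronCores

print(neuron_device_count())
```

The result could then be passed explicitly, e.g. `Fabric(accelerator="xla", devices=neuron_device_count())`, instead of relying on `devices="auto"`.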
Both scripts work with torch 1.13:

```
(venv2) ubuntu@ip-10-0-3-46:~$ pip freeze | grep torch
pytorch-lightning==2.1.1
torch==1.13.1
torch-neuronx==1.13.1.1.12.1
torch-xla==1.13.1+torchneuronc
torchmetrics==1.2.0
(venv2) ubuntu@ip-10-0-3-46:~$ python xla_test.py
xla:1
2023-11-14 17:35:33.000636: 6632 INFO ||NEURON_CACHE||: Compile cache path: /var/tmp/neuron-compile-cache
2023-11-14 17:35:33.000637: 6632 INFO ||NEURON_CC_WRAPPER||: Call compiler with cmd: ['neuronx-cc', '--target=trn1', 'compile', '--framework', 'XLA', '/tmp/ubuntu/neuroncc_compile_workdir/50f76e6a-8df1-4a5c-a6d5-06678791ae7a/model.MODULE_1308431556746240162+d41d8cd9.hlo.pb', '--output', '/tmp/ubuntu/neuroncc_compile_workdir/50f76e6a-8df1-4a5c-a6d5-06678791ae7a/model.MODULE_1308431556746240162+d41d8cd9.neff', '--verbose=35']
.
Compiler status PASS
tensor([[ 0.6607, -1.5726],
        [-1.7505, -1.4779]], device='xla:1')
(venv2) ubuntu@ip-10-0-3-46:~$ python xla-lightning.py
2
```
From the latest AWS Neuron release notes (https://awsdocs-neuron.readthedocs-hosted.com/en/latest/release-notes/index.html#neuron-2-16-0-12-21-2023): "adds new support for PyTorch Lightning Trainer (beta)"
🚀 Feature
https://awsdocs-neuron.readthedocs-hosted.com/en/latest/index.html https://aws.amazon.com/machine-learning/neuron/
Motivation
https://aws.amazon.com/about-aws/whats-new/2022/10/ec2-trn1-instances-high-performance-cost-effective-deep-learning-training/
Pitch
Neuron is XLA-based, so support would likely take the form of an accelerator and a strategy. Their marketing materials also advertise support for FSDP and Megatron-LM; it's unclear at this point whether we would need specialized Neuron-based implementations for those.
We would also need access to hardware for testing.
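To make the accelerator half of the pitch concrete, here is a hypothetical sketch. The base class below is a self-contained stand-in for `lightning.fabric.accelerators.Accelerator`, and every method body is an assumption about what the Neuron glue could look like, not real integration code.

```python
# Hypothetical NeuronAccelerator sketch. The Accelerator base class is
# stubbed so the example runs standalone; in Lightning it would be
# lightning.fabric.accelerators.Accelerator. All method bodies are
# assumptions, not actual Neuron support code.
import os
from typing import List


class Accelerator:  # stand-in for lightning.fabric.accelerators.Accelerator
    def setup_device(self, device) -> None:
        pass

    def teardown(self) -> None:
        pass


class NeuronAccelerator(Accelerator):
    @staticmethod
    def is_available() -> bool:
        # torch_neuronx is the real package name; its importability is
        # used here as a proxy for running on Trainium/Inferentia.
        try:
            import torch_neuronx  # noqa: F401
            return True
        except ImportError:
            return False

    @staticmethod
    def auto_device_count() -> int:
        # Heuristic: read the Neuron runtime's core setting instead of
        # probing GCE instance metadata like the TPU code path does.
        return int(os.environ.get("NEURON_RT_NUM_CORES", 1))

    @staticmethod
    def get_parallel_devices(devices: int) -> List[int]:
        return list(range(devices))
```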