aws-neuron / aws-neuron-sdk

Compilation error when changing the hidden size of a PyTorch linear model. #768

Open woshiyyya opened 1 year ago

woshiyyya commented 1 year ago

I am trying to run the example in this user guide on a trn1.2xlarge instance:

The repro code is below. The script works when HIDDEN_SIZE=10 but fails to compile when HIDDEN_SIZE=1000.

Launch command:

torchrun --nproc_per_node=2 test.py

Script (test.py):

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.optim as optim
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.distributed.xla_backend

torch.distributed.init_process_group('xla')

HIDDEN_SIZE = 10
# HIDDEN_SIZE = 1000

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.net1 = nn.Linear(10, HIDDEN_SIZE)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(HIDDEN_SIZE, 5)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))

def train_fn():
    device = xm.xla_device()
    rank = xm.get_ordinal()

    # Create the model and move to device
    model = Model().to(device)
    ddp_model = DDP(model, gradient_as_bucket_view=True)

    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)
    num_iteration = 100
    for step in range(num_iteration):
        optimizer.zero_grad()
        outputs = ddp_model(torch.randn(20, 10).to(device))
        labels = torch.randn(20, 5).to(device)
        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()
        xm.mark_step()
        if rank == 0:
            print(f"Loss after step {step}: {loss.cpu()}")

train_fn()
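
Switching between the working and the failing configuration means editing the script each time. A minimal sketch of one alternative is to read the value from an environment variable. The helper name `read_hidden_size` and the use of `HIDDEN_SIZE` as an environment variable are hypothetical and not part of the original repro.

```python
import os


def read_hidden_size(default: int = 10) -> int:
    """Return the hidden size from the HIDDEN_SIZE environment variable.

    Hypothetical helper: falls back to `default` (the value that works in
    the report above) when the variable is unset or empty.
    """
    raw = os.environ.get("HIDDEN_SIZE")
    if raw is None or raw.strip() == "":
        return default
    value = int(raw)  # raises ValueError on non-numeric input
    if value <= 0:
        raise ValueError(f"HIDDEN_SIZE must be positive, got {value}")
    return value
```

With this in place, `HIDDEN_SIZE = read_hidden_size()` could replace the hard-coded assignment, and the failing case could be launched as `HIDDEN_SIZE=1000 torchrun --nproc_per_node=2 test.py`, since the worker processes inherit the launcher's environment.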

The error messages:

2023-10-18 12:53:22.000940:  11779  ERROR ||NEURON_CC_WRAPPER||: Compilation failed for /tmp/neuroncc_compile_workdir/1b0d07d8-ed9c-4ee5-b34c-12bc774400ca/model.MODULE_7721032278850290508+d41d8cd9.hlo.pb after 0 retries.
Traceback (most recent call last):
  File "/mnt/cluster_storage/pypi/bin/neuron_cc_wrapper", line 8, in <module>
    sys.exit(main())
  File "/mnt/cluster_storage/pypi/lib/python3.8/site-packages/libneuronxla/neuron_cc_wrapper.py", line 297, in main
    neuron_xla_compile(args.input_file, " ".join(unparsed_args), args.output, cache_key=args.cache_key,
  File "/mnt/cluster_storage/pypi/lib/python3.8/site-packages/libneuronxla/neuron_cc_wrapper.py", line 267, in neuron_xla_compile
    ret = compile_with_cache(output, compile_cache, cache_key, execution_mode, 
  File "/mnt/cluster_storage/pypi/lib/python3.8/site-packages/libneuronxla/neuron_cc_wrapper.py", line 201, in compile_with_cache
    raise(e)
  File "/mnt/cluster_storage/pypi/lib/python3.8/site-packages/libneuronxla/neuron_cc_wrapper.py", line 181, in compile_with_cache
    ret = call_neuron_compiler(
  File "/mnt/cluster_storage/pypi/lib/python3.8/site-packages/libneuronxla/neuron_cc_wrapper.py", line 129, in call_neuron_compiler
    raise RuntimeError(f"Failed compilation with {cmd}: {res.stderr.decode()}")
RuntimeError: Failed compilation with ['neuronx-cc', '--target=trn1', 'compile', '--framework', 'XLA', '/tmp/neuroncc_compile_workdir/1b0d07d8-ed9c-4ee5-b34c-12bc774400ca/model.MODULE_7721032278850290508+d41d8cd9.hlo.pb', '--output', '/tmp/neuroncc_compile_workdir/1b0d07d8-ed9c-4ee5-b34c-12bc774400ca/model.MODULE_7721032278850290508+d41d8cd9.neff', '--verbose=35']: Non-output memory location with no reader {-t1273}@SB<0,0>(1x4)#Internal DebugInfo: <-t1273||UNDEF||[1, 1, 1]>
Non-output memory location with no reader {-t1274}@SB<0,0>(1x4)#Internal DebugInfo: <-t1274||UNDEF||[1, 1, 1]>
walrus_driver: /local/p4clients/pkgbuild-KZijL/workspace/src/KaenaCompiler/neuronxcc/walrus/deadcode_elim/src/dce.cpp:106: void neuronxcc::walrus::bad_instruction_check(bir::Module&, neuronxcc::walrus::Logger&): Assertion `false && "Reading undefined memloc"' failed.

2023-10-18T19:53:22Z [F134] neuronx-cc terminated abnormally - Please open a support ticket at https://github.com/aws-neuron/aws-neuron-sdk/issues/new
2023-10-18T19:53:22Z Walrus driver failed to complete

2023-10-18 12:53:22.989502: W tensorflow/core/framework/op_kernel.cc:1830] OP_REQUIRES failed at tpu_execute_op.cc:266 : INTERNAL: neuronx-cc compilation failed.
2023-10-18 12:53:22.989790: W tensorflow/core/framework/op_kernel.cc:1830] OP_REQUIRES failed at tpu_execute_op.cc:266 : INTERNAL: neuronx-cc compilation failed.
[W logger.cpp:322] Warning: Time stats are currently only collected for CPU and CUDA devices. Please refer to CpuTimer or CudaTimer for how to register timer for other device type. (function operator())
2023-10-18 12:53:23.125978: E tensorflow/compiler/xla/xla_client/xla_util.cc:90] StackTrace:
2023-10-18 12:53:23.126007: E tensorflow/compiler/xla/xla_client/xla_util.cc:90] *** Begin stack trace ***
2023-10-18 12:53:23.126000: E tensorflow/compiler/xla/xla_client/xla_util.cc:90] StackTrace:
2023-10-18 12:53:23.126014: E tensorflow/compiler/xla/xla_client/xla_util.cc:90]        tsl::CurrentStackTrace()
2023-10-18 12:53:23.126017: E tensorflow/compiler/xla/xla_client/xla_util.cc:90] *** Begin stack trace ***
2023-10-18 12:53:23.126023: E tensorflow/compiler/xla/xla_client/xla_util.cc:90]        xla::util::ReportComputationError(tsl::Status const&, absl::lts_20220623::Span<xla::XlaComputation const* const>, absl::lts_20220623::Span<xla::Shape const* const>)
2023-10-18 12:53:23.126024: E tensorflow/compiler/xla/xla_client/xla_util.cc:90]        tsl::CurrentStackTrace()
2023-10-18 12:53:23.126032: E tensorflow/compiler/xla/xla_client/xla_util.cc:90]        xla::XrtComputationClient::ExecuteComputation(xla::ComputationClient::Computation const&, absl::lts_20220623::Span<std::shared_ptr<xla::ComputationClient::Data> const>, std::string const&, xla::ComputationClient::ExecuteComputationOptions const&)
2023-10-18 12:53:23.126032: E tensorflow/compiler/xla/xla_client/xla_util.cc:90]        xla::util::ReportComputationError(tsl::Status const&, absl::lts_20220623::Span<xla::XlaComputation const* const>, absl::lts_20220623::Span<xla::Shape const* const>)
2023-10-18 12:53:23.126039: E tensorflow/compiler/xla/xla_client/xla_util.cc:90] 
2023-10-18 12:53:23.126042: E tensorflow/compiler/xla/xla_client/xla_util.cc:90]        xla::XrtComputationClient::ExecuteComputation(xla::ComputationClient::Computation const&, absl::lts_20220623::Span<std::shared_ptr<xla::ComputationClient::Data> const>, std::string const&, xla::ComputationClient::ExecuteComputationOptions const&)
2023-10-18 12:53:23.126047: E tensorflow/compiler/xla/xla_client/xla_util.cc:90]        xla::util::MultiWait::Complete(std::function<void ()> const&)
2023-10-18 12:53:23.126056: E tensorflow/compiler/xla/xla_client/xla_util.cc:90] 
2023-10-18 12:53:23.126061: E tensorflow/compiler/xla/xla_client/xla_util.cc:90] 
2023-10-18 12:53:23.126063: E tensorflow/compiler/xla/xla_client/xla_util.cc:90]        xla::util::MultiWait::Complete(std::function<void ()> const&)
2023-10-18 12:53:23.126068: E tensorflow/compiler/xla/xla_client/xla_util.cc:90] 
2023-10-18 12:53:23.126070: E tensorflow/compiler/xla/xla_client/xla_util.cc:90] 
2023-10-18 12:53:23.126073: E tensorflow/compiler/xla/xla_client/xla_util.cc:90] 
2023-10-18 12:53:23.126074: E tensorflow/compiler/xla/xla_client/xla_util.cc:90] 
2023-10-18 12:53:23.126080: E tensorflow/compiler/xla/xla_client/xla_util.cc:90]        clone
2023-10-18 12:53:23.126081: E tensorflow/compiler/xla/xla_client/xla_util.cc:90] 
2023-10-18 12:53:23.126091: E tensorflow/compiler/xla/xla_client/xla_util.cc:90] *** End stack trace ***
2023-10-18 12:53:23.126092: E tensorflow/compiler/xla/xla_client/xla_util.cc:90]        clone
2023-10-18 12:53:23.126098: E tensorflow/compiler/xla/xla_client/xla_util.cc:90] 
2023-10-18 12:53:23.126102: E tensorflow/compiler/xla/xla_client/xla_util.cc:90] *** End stack trace ***
2023-10-18 12:53:23.126104: E tensorflow/compiler/xla/xla_client/xla_util.cc:90] Status: INTERNAL: From /job:localservice/replica:0/task:0:
2023-10-18 12:53:23.126109: E tensorflow/compiler/xla/xla_client/xla_util.cc:90] 
2023-10-18 12:53:23.126109: E tensorflow/compiler/xla/xla_client/xla_util.cc:90] 2 root error(s) found.
2023-10-18 12:53:23.126112: E tensorflow/compiler/xla/xla_client/xla_util.cc:90] Status: INTERNAL: From /job:localservice/replica:0/task:0:
2023-10-18 12:53:23.126117: E tensorflow/compiler/xla/xla_client/xla_util.cc:90]   (0) INTERNAL: neuronx-cc compilation failed.
2023-10-18 12:53:23.126121: E tensorflow/compiler/xla/xla_client/xla_util.cc:90] 2 root error(s) found.
2023-10-18 12:53:23.126123: E tensorflow/compiler/xla/xla_client/xla_util.cc:90]         [[{{node XRTExecute}}]]
2023-10-18 12:53:23.126126: E tensorflow/compiler/xla/xla_client/xla_util.cc:90]   (0) INTERNAL: neuronx-cc compilation failed.
2023-10-18 12:53:23.126130: E tensorflow/compiler/xla/xla_client/xla_util.cc:90]         [[XRTExecute_G12]]
2023-10-18 12:53:23.126134: E tensorflow/compiler/xla/xla_client/xla_util.cc:90]         [[{{node XRTExecute}}]]
2023-10-18 12:53:23.126139: E tensorflow/compiler/xla/xla_client/xla_util.cc:90]   (1) INTERNAL: neuronx-cc compilation failed.
2023-10-18 12:53:23.126144: E tensorflow/compiler/xla/xla_client/xla_util.cc:90]         [[XRTExecute_G12]]
2023-10-18 12:53:23.126148: E tensorflow/compiler/xla/xla_client/xla_util.cc:90]         [[{{node XRTExecute}}]]
2023-10-18 12:53:23.126152: E tensorflow/compiler/xla/xla_client/xla_util.cc:90]   (1) INTERNAL: neuronx-cc compilation failed.
2023-10-18 12:53:23.126157: E tensorflow/compiler/xla/xla_client/xla_util.cc:90] 0 successful operations.
2023-10-18 12:53:23.126161: E tensorflow/compiler/xla/xla_client/xla_util.cc:90]         [[{{node XRTExecute}}]]
2023-10-18 12:53:23.126163: E tensorflow/compiler/xla/xla_client/xla_util.cc:90] 0 derived errors ignored.
2023-10-18 12:53:23.126166: E tensorflow/compiler/xla/xla_client/xla_util.cc:90] 0 successful operations.
2023-10-18 12:53:23.126170: E tensorflow/compiler/xla/xla_client/xla_util.cc:90] Recent warning and error logs:
2023-10-18 12:53:23.126171: E tensorflow/compiler/xla/xla_client/xla_util.cc:90] 0 derived errors ignored.
2023-10-18 12:53:23.126178: E tensorflow/compiler/xla/xla_client/xla_util.cc:90]   OP_REQUIRES failed at tpu_execute_op.cc:266 : INTERNAL: neuronx-cc compilation failed.
2023-10-18 12:53:23.126181: E tensorflow/compiler/xla/xla_client/xla_util.cc:90] Recent warning and error logs:
2023-10-18 12:53:23.126183: E tensorflow/compiler/xla/xla_client/xla_util.cc:90]   OP_REQUIRES failed at tpu_execute_op.cc:266 : INTERNAL: neuronx-cc compilation failed.
2023-10-18 12:53:23.126189: E tensorflow/compiler/xla/xla_client/xla_util.cc:90]   OP_REQUIRES failed at tpu_execute_op.cc:266 : INTERNAL: neuronx-cc compilation failed.
Traceback (most recent call last):
  File "test-torchrun.py", line 46, in <module>
    train_fn()
  File "test-torchrun.py", line 36, in train_fn
    outputs = ddp_model(torch.randn(20, 10).to(device))
  File "/mnt/cluster_storage/pypi/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/mnt/cluster_storage/pypi/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 1026, in forward
    if torch.is_grad_enabled() and self.reducer._rebuild_buckets():
RuntimeError: INTERNAL: From /job:localservice/replica:0/task:0:
2 root error(s) found.
  (0) INTERNAL: neuronx-cc compilation failed.
         [[{{node XRTExecute}}]]
         [[XRTExecute_G12]]
  (1) INTERNAL: neuronx-cc compilation failed.
         [[{{node XRTExecute}}]]
0 successful operations.
0 derived errors ignored.
Recent warning and error logs:
  OP_REQUIRES failed at tpu_execute_op.cc:266 : INTERNAL: neuronx-cc compilation failed.
  OP_REQUIRES failed at tpu_execute_op.cc:266 : INTERNAL: neuronx-cc compilation failed.
Traceback (most recent call last):
  File "test-torchrun.py", line 46, in <module>
    train_fn()
  File "test-torchrun.py", line 43, in train_fn
    print(f"Loss after step {step}: {loss.cpu()}")
RuntimeError: INTERNAL: From /job:localservice/replica:0/task:0:
2 root error(s) found.
  (0) INTERNAL: neuronx-cc compilation failed.
         [[{{node XRTExecute}}]]
         [[XRTExecute_G12]]
  (1) INTERNAL: neuronx-cc compilation failed.
         [[{{node XRTExecute}}]]
0 successful operations.
0 derived errors ignored.
Recent warning and error logs:
  OP_REQUIRES failed at tpu_execute_op.cc:266 : INTERNAL: neuronx-cc compilation failed.
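
The output above interleaves stack traces from two worker processes, which buries the compiler's own messages. The following is a rough sketch of a filter that keeps only the lines that appear to come from the compiler. The function name `extract_compiler_errors` is hypothetical, and the patterns are copied from the messages in this log, so they are assumptions about what other failures might look like.

```python
import re

# Patterns modelled on the messages in the log above (assumed, not exhaustive).
COMPILER_PATTERNS = (
    re.compile(r"Non-output memory location with no reader.*"),
    re.compile(r"Assertion `.*' failed\.?"),
    re.compile(r"\[F\d+\] neuronx-cc terminated abnormally.*"),
)


def extract_compiler_errors(log_text: str) -> list:
    """Return the distinct compiler messages found in log_text, in order."""
    seen = set()
    hits = []
    for line in log_text.splitlines():
        for pattern in COMPILER_PATTERNS:
            match = pattern.search(line)
            if match is None:
                continue
            message = match.group(0).strip()
            if message not in seen:  # drop repeats from the second worker
                seen.add(message)
                hits.append(message)
            break  # at most one message per log line
    return hits
```

Saving the output of the failing run to a file and passing its contents to this function should reduce the log to the "no reader" warnings, the failed assertion, and the F134 termination notice.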

Here is the way I set up the environment:

# Configure Linux for Neuron repository updates
. /etc/os-release
sudo tee /etc/apt/sources.list.d/neuron.list > /dev/null <<EOF
deb https://apt.repos.neuron.amazonaws.com/ ${VERSION_CODENAME} main
EOF
wget -qO - https://apt.repos.neuron.amazonaws.com/GPG-PUB-KEY-AMAZON-AWS-NEURON.PUB | sudo apt-key add -

# Update OS packages 
sudo apt-get update -y

# Install Neuron Runtime 
sudo apt-get install aws-neuronx-collectives=2.* -y
sudo apt-get install aws-neuronx-runtime-lib=2.* -y

# Install Neuron Tools 
sudo apt-get install aws-neuronx-tools=2.* -y

# Install torch-neuronx
pip config set global.extra-index-url https://pip.repos.neuron.amazonaws.com/
pip install neuronx-cc==2.* torch-neuronx torchvision
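
Because the install commands pin only major versions (`2.*`) or nothing at all, the exact versions that pip resolved are not visible from the commands alone. A small sketch for collecting them with the standard library follows. The function name `installed_versions` is hypothetical, and including `torch` and `torch-xla` in the list assumes they were pulled in as dependencies.

```python
from importlib import metadata

# Distribution names from the pip command above, plus torch and torch-xla,
# which are assumed to be installed as dependencies.
PACKAGES = ("neuronx-cc", "torch-neuronx", "torch-xla", "torch", "torchvision")


def installed_versions(packages=PACKAGES) -> dict:
    """Map each distribution name to its installed version string."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = "not installed"
    return versions
```

Printing the returned dictionary on the affected instance would give the concrete version numbers to attach to a report like this one.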

micwade-aws commented 1 year ago

Thanks for filing the ticket - will repro and get back to you soon.