Closed seanbenhur closed 1 year ago
As a note: we are looking to add better support for encoder-decoder models with trace, and we are working on performance for these kinds of models.
On the issue itself: we tried to reproduce it, but we don't have access to your custom model, so we used t5-small instead (changed to MODEL_PATH="t5-small"). We noticed that you are moving the model and inputs to the xla device. That is not necessary, so we recommend removing it from the script. With this change, we got past the compilation phase but ran into an assertion error related to the inputs. The error looks something like this:
File "/home/ec2-user/release-211/lib64/python3.7/site-packages/torch/jit/_trace.py", line 983, in trace_module
argument_names,
File "/home/ec2-user/release-211/lib64/python3.7/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
return forward_call(*input, **kwargs)
File "/home/ec2-user/release-211/lib64/python3.7/site-packages/torch/nn/modules/module.py", line 1182, in _slow_forward
result = self.forward(*input, **kwargs)
File "/home/ec2-user/release-211/lib64/python3.7/site-packages/torch_neuronx/xla_impl/trace.py", line 85, in forward
tensors = self.flattener(tensors)
File "/home/ec2-user/release-211/lib64/python3.7/site-packages/torch_neuronx/xla_impl/structure.py", line 206, in __call__
assert self.layout == layout
AssertionError
This error happens because the same input is passed twice. This case is not currently supported during trace, and we are working on a fix. We recommend making the following change:
model_neuron = torch_neuronx.trace(
    model,
    (tpu_input_ids, tpu_mask, tpu_input_ids.clone().detach()),
)
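A minimal sketch of what the `clone().detach()` call changes, assuming the layout check is sensitive to the same tensor object appearing twice in the example inputs (an assumption; the thread only says that duplicated inputs are unsupported). The token ids below are arbitrary illustrative values. The copy holds equal values but is a separate object with its own storage:

```python
import torch

# Arbitrary illustrative token ids (not taken from the thread).
input_ids = torch.tensor([[363, 3, 9, 1627, 239, 55, 1]])

# clone() copies the data into new storage; detach() drops any autograd history.
decoder_input_ids = input_ids.clone().detach()

print(decoder_input_ids is input_ids)                        # False: distinct Python objects
print(decoder_input_ids.data_ptr() == input_ids.data_ptr())  # False: distinct storage
print(torch.equal(decoder_input_ids, input_ids))             # True: identical values
```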
After making these changes, we were able to compile but ran into a final error related to jit.trace. The error looks like this:
Traceback (most recent call last):
File "t5.py", line 40, in <module>
tpu_input_ids.clone().detach()
File "/home/ec2-user/release-211/lib64/python3.7/site-packages/torch_neuronx/xla_impl/trace.py", line 323, in trace
return torch.jit.trace(result, example_inputs, strict=False)
File "/home/ec2-user/release-211/lib64/python3.7/site-packages/torch/jit/_trace.py", line 768, in trace
_module_class,
File "/home/ec2-user/release-211/lib64/python3.7/site-packages/torch/jit/_trace.py", line 983, in trace_module
argument_names,
RuntimeError: Tracer cannot infer type of {'logits': tensor([[[0., 0., 0., ..., 0., 0., 0.],
This happens because the HF T5 model outputs a mixture of tensors and a dictionary of tensors (https://github.com/huggingface/transformers/blob/v4.26.0/src/transformers/models/t5/modeling_t5.py#L1465). This doesn't go well with torch.jit.trace, which expects the outputs to be of a single type. The fix is to wrap the model so that it returns just a tensor or a tuple of tensors, something along these lines:
class Model(torch.nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, input_ids, attention_mask, decoder_input_ids):
        outputs = self.model(input_ids, attention_mask, decoder_input_ids)
        # Note: you can return any number of outputs here,
        # e.g. outputs.logits, outputs.past_key_values, ...
        return outputs.logits

model_wrapper = Model(model)
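The same wrapper pattern can be tried without downloading a model. The sketch below uses a toy module (`DictModel`, a hypothetical stand-in that is not part of transformers or torch-neuronx) that returns a dict, and a wrapper that converts the dict to a tuple of tensors before tracing with plain `torch.jit.trace`:

```python
import torch

class DictModel(torch.nn.Module):
    """Toy stand-in (hypothetical) for a model that returns a dict-like output."""
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 2)

    def forward(self, x):
        hidden = torch.tanh(x)
        return {"logits": self.linear(hidden), "hidden": hidden}

class TupleWrapper(torch.nn.Module):
    """Returns a plain tuple of tensors so the tracer sees a single output type."""
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x):
        outputs = self.model(x)
        return outputs["logits"], outputs["hidden"]

example = torch.rand(1, 4)
wrapper = TupleWrapper(DictModel()).eval()
traced = torch.jit.trace(wrapper, example)

logits, hidden = traced(example)
print(logits.shape, hidden.shape)
```

Returning only the tensors you need also keeps the traced graph smaller than returning every field of the original output.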
After making the above changes the model should compile and we should be able to get the results.
Here is the script that reaches this point:
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import torch
import torch_neuronx
# MODEL_PATH = "/home/ubuntu/ipt_t5"
NEURON_PATH = "model_neuron_t5_v1.pt"
tokenizer = AutoTokenizer.from_pretrained("t5-small")
model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")
def encode(tokenizer, text, max_length=128, batch_size=1):
    tokens = tokenizer.encode_plus(text,
                                   max_length=max_length,
                                   padding=True,
                                   truncation=True,
                                   pad_to_multiple_of=32,
                                   return_tensors='pt')
    return tokens
device = 'xla'
sequence = "What a wonderful day!"
tokenized_sequence = encode(tokenizer,sequence)
print(tokenized_sequence)
tpu_input_ids = tokenized_sequence['input_ids']
tpu_mask = tokenized_sequence['attention_mask']
class Model(torch.nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, input_ids, attention_mask, decoder_input_ids):
        outputs = self.model(input_ids, attention_mask, decoder_input_ids)
        return outputs.logits  # Note: you can return any number of outputs here

model_wrapper = Model(model)
model_neuron = torch_neuronx.trace(
    model_wrapper,
    (tpu_input_ids, tpu_mask, tpu_input_ids.clone().detach()),
)
# Save as TorchScript for inference
torch.jit.save(model_neuron, NEURON_PATH)
preds_cpu = model(input_ids=tokenized_sequence['input_ids'],
                  attention_mask=tokenized_sequence['attention_mask'],
                  decoder_input_ids=tokenized_sequence['input_ids'])
preds_neuron = model_neuron(tokenized_sequence['input_ids'],
                            tokenized_sequence['attention_mask'],
                            tokenized_sequence['input_ids'])
print("="*20)
print("PREDS CPU")
print(preds_cpu.logits)
print("="*20)
print("PREDS NEURON")
print(preds_neuron)
As a recommendation: since these encoder-decoder models are autoregressive, the encoder runs once and the decoder runs multiple times, so the encoder and decoder can be traced separately. We would run the encoder once, then run the decoder repeatedly with padded inputs until the stopping criterion is reached.
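The flow described above can be sketched as follows. This is a plain-Python illustration with stub functions, not the Neuron API: `run_encoder`, `run_decoder`, `greedy_generate`, and `MAX_LEN` are hypothetical names, and the stubs stand in for a separately traced encoder and decoder. The pad and EOS ids are assumed to be 0 and 1, as in the T5 tokenizer. The point is that the encoder is called once, while the decoder is called repeatedly on a fixed-length padded buffer until a stop condition is met.

```python
from typing import List

PAD_ID, EOS_ID, MAX_LEN = 0, 1, 8  # assumed values, for illustration only

def run_encoder(input_ids: List[int]) -> List[int]:
    """Stub for a traced encoder; called exactly once per input."""
    return list(input_ids)

def run_decoder(encoder_state: List[int], padded_decoder_ids: List[int], step: int) -> int:
    """Stub for a traced decoder: echoes the encoder state, then emits EOS.

    A real decoder would return logits; this stub returns the next token id directly.
    """
    return encoder_state[step] if step < len(encoder_state) else EOS_ID

def greedy_generate(input_ids: List[int]) -> List[int]:
    encoder_state = run_encoder(input_ids)   # encoder runs once
    decoder_ids = [PAD_ID] * MAX_LEN         # fixed shape, as a traced model requires
    # Position 0 stays PAD_ID, assuming pad doubles as the decoder start token.
    generated: List[int] = []
    for step in range(MAX_LEN - 1):
        next_id = run_decoder(encoder_state, decoder_ids, step)
        if next_id == EOS_ID:                # stopping criterion
            break
        decoder_ids[step + 1] = next_id      # write into the padded buffer
        generated.append(next_id)
    return generated

print(greedy_generate([5, 6, 7]))  # [5, 6, 7]
```

Keeping the decoder input at a fixed padded length means a single traced decoder graph can serve every step, at the cost of computing over padding positions.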
Thanks for the recommendation @aws-rhsoln. I tried the above code with both my local model and t5-small, but in both cases I got the following error:
To disable this warning, you can either:
- Avoid using `tokenizers` before the fork if possible
- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
2023-06-17T14:13:14Z ERROR 4884 [neuronx-cc]: ***************************************************************
2023-06-17T14:13:14Z ERROR 4884 [neuronx-cc]: An Internal Compiler Error has occurred
2023-06-17T14:13:14Z ERROR 4884 [neuronx-cc]: ***************************************************************
2023-06-17T14:13:14Z ERROR 4884 [neuronx-cc]:
2023-06-17T14:13:14Z ERROR 4884 [neuronx-cc]: Error message: 1572864 requested and 1470400 written
2023-06-17T14:13:14Z ERROR 4884 [neuronx-cc]:
2023-06-17T14:13:14Z ERROR 4884 [neuronx-cc]: Error class: OSError
2023-06-17T14:13:14Z ERROR 4884 [neuronx-cc]: Error location: Unknown
2023-06-17T14:13:14Z ERROR 4884 [neuronx-cc]: Command line: /opt/aws_neuron_venv_pytorch/bin/neuronx-cc compile /tmp/tmp2m3bpw5z/model --framework XLA --target trn1 --output /tmp/tmp2m3bpw5z/graph.neff
2023-06-17T14:13:14Z ERROR 4884 [neuronx-cc]:
2023-06-17T14:13:14Z ERROR 4884 [neuronx-cc]: Internal details:
2023-06-17T14:13:14Z ERROR 4884 [neuronx-cc]: File "neuronxcc/driver/CommandDriver.py", line 259, in neuronxcc.driver.CommandDriver.CommandDriver.run
2023-06-17T14:13:14Z ERROR 4884 [neuronx-cc]: File "neuronxcc/driver/commands/CompileCommand.py", line 1089, in neuronxcc.driver.commands.CompileCommand.CompileCommand.run
2023-06-17T14:13:14Z ERROR 4884 [neuronx-cc]: File "neuronxcc/driver/commands/CompileCommand.py", line 1040, in neuronxcc.driver.commands.CompileCommand.CompileCommand.runPipeline
2023-06-17T14:13:14Z ERROR 4884 [neuronx-cc]: File "neuronxcc/driver/commands/CompileCommand.py", line 1065, in neuronxcc.driver.commands.CompileCommand.CompileCommand.runPipeline
2023-06-17T14:13:14Z ERROR 4884 [neuronx-cc]: File "neuronxcc/driver/commands/CompileCommand.py", line 1069, in neuronxcc.driver.commands.CompileCommand.CompileCommand.runPipeline
2023-06-17T14:13:14Z ERROR 4884 [neuronx-cc]: File "neuronxcc/driver/Job.py", line 300, in neuronxcc.driver.Job.SingleInputJob.run
2023-06-17T14:13:14Z ERROR 4884 [neuronx-cc]: File "neuronxcc/driver/Job.py", line 326, in neuronxcc.driver.Job.SingleInputJob.runOnState
2023-06-17T14:13:14Z ERROR 4884 [neuronx-cc]: File "neuronxcc/driver/Pipeline.py", line 30, in neuronxcc.driver.Pipeline.Pipeline.runSingleInput
2023-06-17T14:13:14Z ERROR 4884 [neuronx-cc]: File "neuronxcc/driver/Job.py", line 300, in neuronxcc.driver.Job.SingleInputJob.run
2023-06-17T14:13:14Z ERROR 4884 [neuronx-cc]: File "neuronxcc/driver/Job.py", line 326, in neuronxcc.driver.Job.SingleInputJob.runOnState
2023-06-17T14:13:14Z ERROR 4884 [neuronx-cc]: File "neuronxcc/driver/jobs/Frontend.py", line 359, in neuronxcc.driver.jobs.Frontend.Frontend.runSingleInput
2023-06-17T14:13:14Z ERROR 4884 [neuronx-cc]: File "neuronxcc/driver/jobs/Frontend.py", line 171, in neuronxcc.driver.jobs.Frontend.Frontend.runXLAFrontend
2023-06-17T14:13:14Z ERROR 4884 [neuronx-cc]: File "neuronxcc/driver/jobs/support/Partitioning.py", line 308, in neuronxcc.driver.jobs.support.Partitioning.HhSubgraphs.write
2023-06-17T14:13:14Z ERROR 4884 [neuronx-cc]: File "neuronxcc/driver/jobs/support/Partitioning.py", line 183, in neuronxcc.driver.jobs.support.Partitioning.HhSubgraph.writeHhData
2023-06-17T14:13:14Z ERROR 4884 [neuronx-cc]: File "neuronxcc/driver/jobs/support/Partitioning.py", line 174, in neuronxcc.driver.jobs.support.Partitioning.HhSubgraph.dump_params
2023-06-17T14:13:14Z ERROR 4884 [neuronx-cc]: File "<__array_function__ internals>", line 180, in save
2023-06-17T14:13:14Z ERROR 4884 [neuronx-cc]: File "/opt/aws_neuron_venv_pytorch/lib/python3.8/site-packages/numpy/lib/npyio.py", line 519, in save
2023-06-17T14:13:14Z ERROR 4884 [neuronx-cc]: format.write_array(fid, arr, allow_pickle=allow_pickle,
2023-06-17T14:13:14Z ERROR 4884 [neuronx-cc]: File "/opt/aws_neuron_venv_pytorch/lib/python3.8/site-packages/numpy/lib/format.py", line 690, in write_array
2023-06-17T14:13:14Z ERROR 4884 [neuronx-cc]: array.tofile(fp)
2023-06-17T14:13:14Z ERROR 4884 [neuronx-cc]:
2023-06-17T14:13:14Z ERROR 4884 [neuronx-cc]: Version information:
2023-06-17T14:13:14Z ERROR 4884 [neuronx-cc]: NeuronX Compiler version 2.6.0.19+3d819e565
2023-06-17T14:13:14Z ERROR 4884 [neuronx-cc]:
2023-06-17T14:13:14Z ERROR 4884 [neuronx-cc]: Python version 3.8.10
2023-06-17T14:13:14Z ERROR 4884 [neuronx-cc]: HWM version 2.6.0.0-826e77395
2023-06-17T14:13:14Z ERROR 4884 [neuronx-cc]: NEFF version Dynamic
2023-06-17T14:13:14Z ERROR 4884 [neuronx-cc]: TVM version 1.15.0.0+0
2023-06-17T14:13:14Z ERROR 4884 [neuronx-cc]: NumPy version 1.22.4
2023-06-17T14:13:14Z ERROR 4884 [neuronx-cc]: MXNet not available
2023-06-17T14:13:14Z ERROR 4884 [neuronx-cc]:
2023-06-17T14:13:14Z ERROR 4884 [neuronx-cc]: Artifacts stored in: /home/ubuntu/inferentia_t5/neuronxcc-z8ar2cxe
Traceback (most recent call last):
File "convert_to_neuronx.py", line 46, in <module>
model_neuron = torch_neuronx.trace(model_wrapper,(input_ids,
File "/opt/aws_neuron_venv_pytorch/lib/python3.8/site-packages/torch_neuronx/xla_impl/trace.py", line 309, in trace
neff_filename = hlo_compile(model_dir, compiler_workdir, compiler_args)
File "/opt/aws_neuron_venv_pytorch/lib/python3.8/site-packages/torch_neuronx/xla_impl/trace.py", line 232, in hlo_compile
raise RuntimeError(f'neuronx-cc failed with {status}')
RuntimeError: neuronx-cc failed with 1
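The `OSError` in the log above is raised from NumPy's `array.tofile`, and its message reports that fewer bytes were written than requested. One possible cause, which this thread does not confirm, is that the filesystem holding the compiler's working directory ran out of space. A quick check, assuming the system temporary directory (the log's command line points at `/tmp`) is the relevant location; `free_gib` is a hypothetical helper name:

```python
import shutil
import tempfile

def free_gib(path: str) -> float:
    """Return the free space on the filesystem containing `path`, in GiB."""
    usage = shutil.disk_usage(path)
    return usage.free / 2**30

tmp_dir = tempfile.gettempdir()
print(f"{tmp_dir}: {free_gib(tmp_dir):.2f} GiB free")
```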
We are able to compile t5-small with the above script and neuronx-cc=2.6.0.19+3d819e565. Can you share the torch-neuronx, torch-xla, and transformers versions you are using?
I am able to compile the model with the above neuronx-cc version. Thanks for the support!
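One way to gather the requested versions in a single step is sketched below. It relies only on the standard library's `importlib.metadata` (Python 3.8+); `collect_versions` is a hypothetical helper name, and the package names are assumed to match the pip distribution names.

```python
from importlib.metadata import PackageNotFoundError, version

def collect_versions(packages):
    """Map each distribution name to its installed version, or 'not installed'."""
    report = {}
    for name in packages:
        try:
            report[name] = version(name)
        except PackageNotFoundError:
            report[name] = "not installed"
    return report

packages = ["torch-neuronx", "torch-xla", "transformers", "neuronx-cc"]
for name, ver in collect_versions(packages).items():
    print(f"{name}: {ver}")
```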
Sorry for piggy-backing on this issue, but I encountered a somewhat similar issue when trying to compile the Donut model for Inferentia2. Can you please let me know what fix was applied? Thanks.
Compiler status PASS
ERROR ***************************************************************
ERROR An Internal Compiler Error has occurred
ERROR ***************************************************************
ERROR
ERROR Error message: simplifyStmtPredicates() takes exactly 1 positional argument (2 given)
ERROR
ERROR Error class: TypeError
ERROR Error location: Unknown
ERROR
Traceback (most recent call last):
File "/trace-model/trace-model.py", line 63, in <module>
model_traced = neuron_lib.trace(model, example_inputs, compiler_workdir=f'{chip_type}-compiler-workdir')
File "/opt/conda/lib/python3.10/site-packages/torch_neuronx/xla_impl/trace.py", line 289, in trace
neff_filename, metaneff, flattener, packer = _trace(
File "/opt/conda/lib/python3.10/site-packages/torch_neuronx/xla_impl/trace.py", line 357, in _trace
neff_filename = hlo_compile(model_dir, compiler_workdir, compiler_args)
File "/opt/conda/lib/python3.10/site-packages/torch_neuronx/xla_impl/trace.py", line 249, in hlo_compile
raise RuntimeError(f'neuronx-cc failed with {status}')
RuntimeError: neuronx-cc failed with 1
@dneemuth Just to update: I saw the other issue you filed (https://github.com/aws-neuron/aws-neuron-sdk/issues/732), and this is confirmed to be a separate problem. We will keep the other ticket updated with progress.
@jluntamazon thanks for following up. 100% on this, the new issue can be tracked separately via https://github.com/aws-neuron/aws-neuron-sdk/issues/732 . appreciate the efforts here & keep up the good work. cheers.
I am trying to compile a T5 model with torch-neuronx on an inf2 instance. When I compile the model, I get a segmentation fault.
Here is my code
Here is a full error trace