aws-neuron / aws-neuron-samples

Example code for AWS Neuron SDK developers building inference and training applications

Using torch_neuronx models for Causal Language Models #32

Closed Bhuvanesh09 closed 9 months ago

Bhuvanesh09 commented 9 months ago

The sample code for GPT2 at https://github.com/aws-neuron/aws-neuron-samples/blob/master/torch-neuronx/inference/hf_pretrained_gpt2_feature_extraction_on_trn1.ipynb recommends that we pad the input before passing it to the forward pass:

torch_neuronx.trace() expects a tensor or tuple of tensor inputs to use for tracing, so we unpack the tokenizer output. Additionally, the input shape that's used during compilation must match the input shape that's used during inference. To handle this, we pad the inputs to the maximum size that we will see during inference.

But it has been observed that padding to the right for Causal Models leads to inaccurate results, as can be seen here: https://github.com/huggingface/transformers/issues/14521#issuecomment-990895170

Additionally, torch_neuronx supports dynamic input only along its first (batch) dimension. Whereas for any Causal LM, the length of the input rises along the sequence dimension after sampling in each subsequent forward pass.

Is there a recommended way to use torch_neuronx for Causal Language Models?

hannanjgaws commented 9 months ago

Hi @Bhuvanesh09:

In general, causal generative decoding requires LLM-specific support to be performant. We have added optimized support for several causal language models in the transformers-neuronx library, including GPT2, LLaMA, and BLOOM. To get started with this library, please see our transformers-neuronx documentation.

RE: It has been observed that padding to the right for Causal Models leads to inaccurate results:

The transformers-neuronx library expects that inputs are left padded (right-aligned) for batched inputs of varying lengths.
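
To make the difference between the two padding sides concrete, here is a minimal pure-Python sketch. It is not transformers-neuronx code: the `pad_batch` helper and the pad token id `0` are hypothetical, chosen only for illustration. It shows that with left padding the last position of every row holds a real token, which is the position a causal model samples the next token from.

```python
from typing import List, Tuple

PAD_ID = 0  # hypothetical pad token id, for illustration only


def pad_batch(seqs: List[List[int]], side: str = "left",
              pad_id: int = PAD_ID) -> Tuple[List[List[int]], List[List[int]]]:
    """Pad every sequence to the longest length in the batch.

    Returns (input_ids, attention_mask); the mask is 1 for real tokens
    and 0 for padding.
    """
    if side not in ("left", "right"):
        raise ValueError("side must be 'left' or 'right'")
    max_len = max(len(s) for s in seqs)
    input_ids, attention_mask = [], []
    for s in seqs:
        pad = [pad_id] * (max_len - len(s))
        ones = [1] * len(s)
        zeros = [0] * len(pad)
        if side == "left":
            input_ids.append(pad + s)
            attention_mask.append(zeros + ones)
        else:
            input_ids.append(s + pad)
            attention_mask.append(ones + zeros)
    return input_ids, attention_mask


batch = [[11, 12, 13], [21, 22, 23, 24, 25]]
left_ids, left_mask = pad_batch(batch, side="left")
right_ids, right_mask = pad_batch(batch, side="right")

# Mask value at the final position of each row: 1 means a real token sits there.
left_last = [row[-1] for row in left_mask]
right_last = [row[-1] for row in right_mask]
print("left padding, last position is real:", left_last)
print("right padding, last position is real:", right_last)
```

With right padding, the shorter row ends in a pad token, so reading the logits at the final position would continue from padding rather than from the prompt.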

RE: Whereas for any Causal LM, the length of the input rises along the sequence dimension after sampling in each subsequent forward pass:

Transformers-neuronx supports autoregressive sampling via bucketing, which handles dynamic input and output sequence lengths. For more information about bucketing, please refer to this documentation about LLM inference on Neuron.
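
As a rough illustration of the general idea behind bucketing (not the actual transformers-neuronx implementation), the sketch below assumes the model has been compiled for a small set of fixed sequence lengths and picks the smallest one that fits the current sequence. The bucket sizes and the `select_bucket` helper are hypothetical.

```python
from bisect import bisect_left
from typing import List

# Hypothetical bucket sizes; real values depend on how the model was compiled.
BUCKETS = [128, 256, 512, 1024]


def select_bucket(seq_len: int, buckets: List[int] = BUCKETS) -> int:
    """Return the smallest bucket that can hold seq_len tokens.

    `buckets` must be sorted in ascending order.
    """
    i = bisect_left(buckets, seq_len)
    if i == len(buckets):
        raise ValueError(
            f"sequence length {seq_len} exceeds largest bucket {buckets[-1]}"
        )
    return buckets[i]


# Simulate a decoding loop: the sequence grows by one token per step,
# and the selected bucket only changes when a boundary is crossed.
prompt_len = 126
chosen = [(n, select_bucket(n)) for n in range(prompt_len, prompt_len + 5)]
for n, bucket in chosen:
    print(f"length {n} -> bucket {bucket}")
```

Because only a handful of shapes are ever used, each can be compiled ahead of time, while the sequence is padded up to the selected bucket at run time.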

I'll close this ticket, but please feel free to open a new one if you encounter any issues using transformers-neuronx or torch-neuronx.