aws-neuron / aws-neuron-sdk

Powering AWS purpose-built machine learning chips. Blazing fast and cost effective, natively integrated into PyTorch and TensorFlow and integrated with your favorite AWS services
https://aws.amazon.com/machine-learning/neuron/

Internal tensorizer error on RWKV model #878

Open pm-mck opened 2 months ago

pm-mck commented 2 months ago

Hello,

I am working on tracing RWKV with neuronx, and I received the following error:

Error

[TEN404] (_dynamic-update-slice.5283) Internal tensorizer error - Please open a support ticket at https://github.com/aws-neuron/aws-neuron-sdk/issues/new

Version Info

2024-04-29T01:35:48Z Diagnostic information:
2024-04-29T01:35:48Z   NeuronX Compiler version 2.13.68.0+6dfecc895
2024-04-29T01:35:48Z   
2024-04-29T01:35:48Z   Python version 3.10.12
2024-04-29T01:35:48Z   HWM version 2.13.68.0+6dfecc895
2024-04-29T01:35:48Z   NumPy version 1.25.2
2024-04-29T01:35:48Z   
2024-04-29T01:35:48Z   Running on AMI ami-086b8d42c7e9f91d7
2024-04-29T01:35:48Z   Running in region use2-az2

The code I'm using to trace is pretty basic. Maybe it's too basic, since it's also reporting this warning:

/opt/aws_neuronx_venv_pytorch_2_1/lib/python3.10/site-packages/torch_neuronx/xla_impl/hlo_conversion.py:144: UserWarning: Received an input tensor that was unused. Tensor will be ignored. (index=0, shape=torch.Size([39]), dtype=torch.int64)
  warnings.warn(

Here's the code; it's derived from the RWKV chat example:

Code

import torch_xla.core.xla_model as xm

import os, copy, types, gc, sys, re
import numpy as np
from prompt_toolkit import prompt
import torch
import torch_neuronx

os.environ["RWKV_CUDA_ON"] = "0"  # '1' to compile CUDA kernel (10x faster), requires c++ compiler & cuda libraries

from rwkv.model import RWKV
from rwkv.utils import PIPELINE

########################################################################################################

args = types.SimpleNamespace()

args.strategy = "xla:0 fp32"  # use CUDA, fp16
#args.strategy = 'cpu fp32'
args.MODEL_NAME = "/home/ubuntu/RWKV-x060-World-3B-v2.1-20240417-ctx4096"

GEN_TEMP = 1.0
GEN_TOP_P = 0.3
GEN_alpha_presence = 0.0
GEN_alpha_frequency = 1.0
GEN_penalty_decay = 0.996

CHUNK_LEN = 256  # split input into chunks of this size (shorter -> slower, but saves VRAM)

########################################################################################################

class ModelWrap():
    # Thin wrapper that keeps the RWKV recurrent state between calls so the
    # traced callable takes only the token tensor as input.
    model_state = None
    model = None

    def __init__(self, model):
        self.model = model

    def forward(self, toks):
        out, model_state = self.model.forward(toks, self.model_state)
        self.model_state = model_state
        return out, model_state

    def __call__(self, t):
        return self.forward(t)

print(f"Loading model - {args.MODEL_NAME}")
model_rwkv = RWKV(model=args.MODEL_NAME, strategy=args.strategy)
pipeline = PIPELINE(model_rwkv, "rwkv_vocab_v20230424")
model = ModelWrap(model_rwkv)

model_tokens = []
model_state = None
ctx = "User: hi" + "\n\n"
ctx += "Assistant: Hi. I am your assistant and I will provide expert full response in full details. Please feel free to ask any question and I will always answer it." + "\n\n"
ctx = ctx.replace("\r\n", "\n")

tokens = pipeline.encode(ctx)
tokens = [int(x) for x in tokens]
model_tokens += tokens

t = torch.tensor(tokens[:CHUNK_LEN], device='cpu')  # example input: the encoded prompt, capped at CHUNK_LEN tokens
model_neuron = torch_neuronx.trace(model, t)
torch.jit.save(model_neuron, 'model.pt')
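
For completeness, this is roughly how I'd expect to load and run the saved artifact once the trace succeeds (hypothetical follow-up, not reached yet because of the error above; the traced model is shape-specific, so the input must have the same number of tokens as t):

model_loaded = torch.jit.load('model.pt')
result = model_loaded(t)  # should mirror ModelWrap.forward's (logits, state) output
print(type(result))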

Any help is appreciated!

jyang-aws commented 2 months ago

Hi @pm-mck ,

Thanks for reporting the issue. I'm trying to reproduce it with our latest release, but I hit the issue below:

Loading model - /home/ubuntu/RWKV-x060-World-3B-v2.1-20240417-ctx4096
Traceback (most recent call last):
  File "gh.py", line 49, in <module>
    model_rwkv = RWKV(model=args.MODEL_NAME, strategy=args.strategy)
  File "/shared/on-call/env/lib/python3.8/site-packages/torch/jit/_script.py", line 303, in init_then_script
    original_init(self, *args, **kwargs)
  File "/shared/on-call/env/lib/python3.8/site-packages/rwkv/model.py", line 186, in __init__
    raise ValueError("Invalid strategy. Please read https://pypi.org/project/rwkv/")
ValueError: Invalid strategy. Please read https://pypi.org/project/rwkv/

Did I miss something?

pm-mck commented 2 months ago

Hi @jyang-aws - thank you for your response. Yes, I had to modify the rwkv package itself to allow XLA tensors. Nothing major: it has a regex check that validates strategies, and I also forced it to use the CPU for indexed tensors. I can send you a patch if that's helpful.
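
For illustration, the strategy-check part amounts to widening the device alternation in the validation regex so that "xla" devices pass. This is a minimal, hypothetical sketch (the exact pattern in rwkv/model.py may differ, and the CPU-for-indexed-tensors change is not shown):

import re

# hypothetical widened pattern; the stock check only accepts cpu/cuda devices
STRATEGY_RE = re.compile(r"^(?:cpu|cuda(?::\d+)?|xla(?::\d+)?) (?:fp16|fp32|bf16)")

for s in ["cpu fp32", "cuda:0 fp16", "xla:0 fp32"]:
    print(s, "->", "ok" if STRATEGY_RE.match(s) else "invalid strategy")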

jyang-aws commented 1 month ago

Thanks. I'm able to reproduce the issue now. We'll fix it on our end and keep you updated.

PrateekAg1511 commented 3 weeks ago

@jyang-aws I am facing the exact same issue! Is there a fix for it?