aws-neuron / aws-neuron-sdk

Powering AWS purpose-built machine learning chips. Blazing fast and cost effective, natively integrated into PyTorch and TensorFlow and integrated with your favorite AWS services.

Internal tensorizer error on RWKV model #878

Open pm-mck opened 2 months ago

pm-mck commented 2 months ago


I am working on tracing RWKV using neuronx and I received the following error:


[TEN404] (_dynamic-update-slice.5283) Internal tensorizer error - Please open a support ticket at
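The failing instruction is a dynamic-update-slice, an XLA operation that writes a smaller tensor into a larger one at given start indices. The thread does not say which part of RWKV's forward pass produces it; slice assignments on the recurrent state are one plausible source, but that is an assumption. The pure-Python sketch below only illustrates the op's documented semantics, including the clamping of start indices so the update always fits. `dynamic_update_slice_2d` is a hypothetical helper, not Neuron or XLA code.

```python
def dynamic_update_slice_2d(operand, update, start):
    """Return a copy of `operand` with `update` written at `start` (row, col).

    Follows the XLA DynamicUpdateSlice semantics: each start index is clamped
    to [0, operand_dim - update_dim] so the update always fits in the operand.
    """
    rows, cols = len(operand), len(operand[0])
    u_rows, u_cols = len(update), len(update[0])
    if u_rows > rows or u_cols > cols:
        raise ValueError("update must not be larger than operand")

    # Clamp the start indices instead of raising on out-of-range values.
    r0 = min(max(start[0], 0), rows - u_rows)
    c0 = min(max(start[1], 0), cols - u_cols)

    # The op is functional: copy the operand rather than mutating it.
    result = [row[:] for row in operand]
    for i in range(u_rows):
        for j in range(u_cols):
            result[r0 + i][c0 + j] = update[i][j]
    return result
```

For example, writing `[[1, 2], [3, 4]]` into a 3x4 zero matrix at start `(5, -3)` lands at `(1, 0)` after clamping.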

Version Info

2024-04-29T01:35:48Z Diagnostic information:
2024-04-29T01:35:48Z   NeuronX Compiler version
2024-04-29T01:35:48Z   Python version 3.10.12
2024-04-29T01:35:48Z   HWM version
2024-04-29T01:35:48Z   NumPy version 1.25.2
2024-04-29T01:35:48Z   Running on AMI ami-086b8d42c7e9f91d7
2024-04-29T01:35:48Z   Running in region use2-az2

The code I'm using to trace is pretty basic, maybe too basic, because tracing also reports this warning:

/opt/aws_neuronx_venv_pytorch_2_1/lib/python3.10/site-packages/torch_neuronx/xla_impl/ UserWarning: Received an input tensor that was unused. Tensor will be ignored. (index=0, shape=torch.Size([39]), dtype=torch.int64)

Here's the code, it's derived from the RWKV chat example:


import torch_xla.core.xla_model as xm

import os, copy, types, gc, sys, re
import numpy as np
from prompt_toolkit import prompt
import torch
import torch_neuronx

os.environ["RWKV_CUDA_ON"] = "0"  # '1' to compile CUDA kernel (10x faster), requires c++ compiler & cuda libraries

from rwkv.model import RWKV
from rwkv.utils import PIPELINE


args = types.SimpleNamespace()

args.strategy = "xla:0 fp32"  # run on the XLA (Neuron) device in fp32
#args.strategy = 'cpu fp32'
args.MODEL_NAME = "/home/ubuntu/RWKV-x060-World-3B-v2.1-20240417-ctx4096"

GEN_TEMP = 1.0
GEN_TOP_P = 0.3
GEN_alpha_presence = 0.0
GEN_alpha_frequency = 1.0
GEN_penalty_decay = 0.996

CHUNK_LEN = 256  # split input into chunks to save VRAM (shorter -> slower, but saves VRAM)


# Thin callable wrapper that keeps RWKV's recurrent state between calls
class ModelWrap():
    model_state = None
    model = None

    def __init__(self, model):
        self.model = model

    def forward(self, toks):
        out, model_state = self.model.forward(toks, self.model_state)
        self.model_state = model_state
        return out, model_state

    def __call__(self, t):
        return self.forward(t)

print(f"Loading model - {args.MODEL_NAME}")
model_rwkv = RWKV(model=args.MODEL_NAME, strategy=args.strategy)
pipeline = PIPELINE(model_rwkv, "rwkv_vocab_v20230424")
model = ModelWrap(model_rwkv)

model_tokens = []
model_state = None
ctx = "User: hi" + "\n\n"
ctx += "Assistant: Hi. I am your assistant and I will provide expert full response in full details. Please feel free to ask any question and I will always answer it." + "\n\n"
ctx = ctx.replace("\r\n", "\n")

tokens = pipeline.encode(ctx)
tokens = [int(x) for x in tokens]
model_tokens += tokens

t = torch.tensor(tokens[:CHUNK_LEN], device='cpu')
model_neuron = torch_neuronx.trace(model, t)

Any help is appreciated!
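The script traces a single chunk, `tokens[:CHUNK_LEN]`; with the 39-token prompt reported in the warning, that chunk is the whole context. The RWKV chat example this script is derived from feeds longer prompts through the model in `CHUNK_LEN`-sized pieces, passing the recurrent state from one call to the next. A minimal sketch of that loop follows; `feed_in_chunks` and `toy_forward` are hypothetical names, and `toy_forward` is a stand-in for `model.forward` so the sketch runs without RWKV or Neuron.

```python
def feed_in_chunks(forward, tokens, state=None, chunk_len=256):
    """Feed `tokens` through `forward` chunk_len tokens at a time.

    `forward(chunk, state)` must return `(out, new_state)`; the state from
    each call is passed into the next, and the last `out` is returned.
    """
    out = None
    while tokens:
        out, state = forward(tokens[:chunk_len], state)
        tokens = tokens[chunk_len:]
    return out, state


def toy_forward(chunk, state):
    """Hypothetical stand-in for model.forward: 'state' counts tokens seen."""
    seen = (state or 0) + len(chunk)
    return len(chunk), seen


out, state = feed_in_chunks(toy_forward, list(range(600)), chunk_len=256)
print(out, state)  # last chunk holds 88 tokens; 600 tokens seen in total
```

Since `torch_neuronx.trace` compiles for the shapes of its example inputs, a traced model would presumably need every chunk to match the traced length; how to handle a shorter final chunk is left open here.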

jyang-aws commented 2 months ago

Hi @pm-mck,

Thanks for reporting the issue. I'm trying to reproduce it with our latest release, but I hit the following error:

Loading model - /home/ubuntu/RWKV-x060-World-3B-v2.1-20240417-ctx4096
Traceback (most recent call last):
  File "", line 49, in <module>
    model_rwkv = RWKV(model=args.MODEL_NAME, strategy=args.strategy)
  File "/shared/on-call/env/lib/python3.8/site-packages/torch/jit/", line 303, in init_then_script
    original_init(self, *args, **kwargs)
  File "/shared/on-call/env/lib/python3.8/site-packages/rwkv/", line 186, in __init__
    raise ValueError("Invalid strategy. Please read")
ValueError: Invalid strategy. Please read

Did I miss something?

pm-mck commented 2 months ago

Hi @jyang-aws, thank you for your response. Yes, I had to modify the RWKV package itself to allow XLA tensors. Nothing major: it has a regex check that validates strategy strings, and I also forced it to use the CPU for indexed tensors. I can send you a patch if that's helpful.
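The `ValueError: Invalid strategy` above is that validation rejecting `"xla:0 fp32"`. The patch itself is not posted in the thread, so the sketch below is only a hypothetical illustration of a strategy check extended with an `xla` device; the device and precision lists are assumptions, and this is not the regex shipped in the rwkv package (whose strategy syntax supports more options, such as per-layer splits).

```python
import re

# Hypothetical grammar, NOT the rwkv package's own regex.
_DEVICE = r"(?:cpu|cuda(?::\d+)?|xla(?::\d+)?)"  # xla:N added for Neuron
_PRECISION = r"(?:fp16|fp32|bf16)(?:i8)?"
_STAGE = re.compile(rf"{_DEVICE} {_PRECISION}")


def is_valid_strategy(strategy: str) -> bool:
    """Accept one or more 'device precision' stages joined by '->'."""
    stages = [stage.strip() for stage in strategy.split("->")]
    return all(_STAGE.fullmatch(stage) for stage in stages)


print(is_valid_strategy("xla:0 fp32"))  # True with the xla alternative added
print(is_valid_strategy("tpu fp32"))    # False: unknown device
```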

jyang-aws commented 1 month ago

Thanks. I'm able to reproduce the issue now. We will fix it on our end and keep you updated.

PrateekAg1511 commented 3 weeks ago

@jyang-aws I am facing the exact same issue! Is there a fix for it?