togethercomputer / stripedhyena

Repository for StripedHyena, a state-of-the-art beyond Transformer architecture
Apache License 2.0

evo and stripedhyena crash the server when doing a simple inference #21

Open sun-qibo opened 2 months ago

sun-qibo commented 2 months ago

I observed similar behavior in evo and stripedhyena: the model loads successfully, but the server crashes as soon as I try to run a simple inference.

compute resource: Databricks cluster on Azure with an NVIDIA A100

pkgs: flash-fft-conv and flash-attention are both installed correctly, so the model loads without problems
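
For anyone trying to reproduce this, it may help to list the exact package versions alongside the report. The following is only a minimal sketch using the standard-library `importlib.metadata`; the helper name `installed_versions` is hypothetical, and the distribution names in the usage part (`flash-attn`, `flashfftconv`) are assumptions about how the packages were installed, not something confirmed in this issue.

```python
from importlib import metadata


def installed_versions(dist_names):
    """Map each distribution name to its installed version, or None if absent."""
    versions = {}
    for name in dist_names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            # The distribution is not installed under this name.
            versions[name] = None
    return versions


if __name__ == "__main__":
    # Assumed distribution names; adjust them to match your environment.
    names = ["torch", "transformers", "flash-attn", "flashfftconv"]
    for name, version in installed_versions(names).items():
        print(f"{name}: {version or 'not installed'}")
```

A `None` entry only means that no distribution is registered under that name; a package built from source may be registered under a different name.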

code for stripedhyena:

import sys

import torch
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, TextStreamer

model_name = "togethercomputer/StripedHyena-Hessian-7B"
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    model_max_length=sys.maxsize,
    trust_remote_code=True,
)

tokenizer.pad_token = tokenizer.eos_token

config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
config.use_cache = True

device = torch.device("cuda")
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    config=config,
    trust_remote_code=True,
).to(device)

input_text = "Question: How many hours in one day? Answer: "
input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(device)

model.generate(input_ids)

code for evo:

import torch
from transformers import AutoConfig, AutoModelForCausalLM
from stripedhyena.tokenizer import CharLevelTokenizer

device = torch.device("cuda")  # same device as in the stripedhyena snippet

tokenizer = CharLevelTokenizer(512)

hf_model_name = 'togethercomputer/evo-1-131k-base'

model_config = AutoConfig.from_pretrained(
    hf_model_name,
    trust_remote_code=True,
    revision='1.1_fix',
)

model_config.use_cache = True

model = AutoModelForCausalLM.from_pretrained(
    hf_model_name,
    config=model_config,
    trust_remote_code=True,
    revision='1.1_fix',
)

sequence = 'ACGT'

input_ids = torch.tensor(
    tokenizer.tokenize(sequence),
    dtype=torch.int,
).to(device).unsqueeze(0)

with torch.no_grad():
    logits, _ = model(input_ids) 

Both snippets crashed the Python kernel at the last line.
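
A hard kernel crash often leaves no Python traceback, which makes it difficult to tell where the failure happens. One possible way to capture more detail is the standard-library `faulthandler` module, which dumps the Python stack when the process receives a fatal signal such as a segmentation fault. The sketch below is an assumption about how this could be set up in a notebook; the helper name `enable_crash_log` and the log file name are hypothetical.

```python
import faulthandler


def enable_crash_log(path="crash_trace.log"):
    """Dump the Python stack of all threads to `path` on a fatal signal.

    Returns the open file object; the caller must keep it open for as
    long as the handler should stay active.
    """
    log_file = open(path, "w")
    faulthandler.enable(file=log_file, all_threads=True)
    return log_file
```

Calling `log_file = enable_crash_log()` before the final `model.generate(input_ids)` or `model(input_ids)` line, and then inspecting `crash_trace.log` after the kernel restarts, should show the last Python frames if the crash was caused by a signal that `faulthandler` handles. If the process was instead killed by the operating system (for example with SIGKILL after running out of host memory), the file stays empty, which is itself a useful hint.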

I was not sure whether the issue was caused by the configuration of my Databricks resources, so I also tried other models of the same size picked at random, e.g. "HuggingFaceH4/zephyr-7b-beta", and inference ran without any problem. I do not know whether there is some other incompatibility between stripedhyena and Databricks, though.

Has anyone else encountered this problem?