zkonduit / ezkl

ezkl is an engine for doing inference for deep learning models and other computational graphs in a zk-snark (ZKML). Use it from Python, Javascript, or the command line.
https://docs.ezkl.xyz/

GPTNeoXForCausalLM models shape issue #634

Closed: tobinsouth closed this issue 9 months ago

tobinsouth commented 9 months ago

Generating settings (ezkl gen_settings / gen-settings) panics on any Hugging Face GPTNeoXForCausalLM model exported to ONNX. Interestingly, it does not panic for GPTNeoForCausalLM models.

TL;DR:

thread 'main' panicked at src/tensor/mod.rs:968:13:
assertion failed: shape.contains(d) || *d == 1
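The panic message shows only the asserted expression. Read literally, it requires every dimension d being checked to either occur somewhere in shape or equal 1, which looks like a broadcast-compatibility check. Below is a minimal Python restatement under that reading. It assumes d iterates over the dimensions of a source tensor and shape is the target shape. The actual ezkl code at src/tensor/mod.rs:968 is not quoted here, and the function name broadcast_dims_ok is hypothetical.

```python
def broadcast_dims_ok(src_shape, target_shape):
    """Hypothetical restatement of the asserted condition.

    Returns True when every dimension d in src_shape either occurs somewhere
    in target_shape or equals 1. Note that this is a membership test, not a
    position-by-position comparison.
    """
    return all(d in target_shape or d == 1 for d in src_shape)


print(broadcast_dims_ok([1, 8], [1, 8, 512]))  # True
print(broadcast_dims_ok([7, 8], [1, 8, 512]))  # False: 7 is neither 1 nor in the target
```

Under this reading, the panic means that some tensor in the imported graph has a dimension matching neither the target shape nor 1. In other words, the shapes inferred for the GPTNeoX graph disagree somewhere.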

Device and Operating System

ezkl 5.0.8, transformers 4.35, Ubuntu 22.04, 40-core Intel(R) Xeon(R) CPU E5-2687W v3 @ 3.10GHz, 1005 GB memory

Replication code

import torch, json
import torch.nn as nn

from transformers import GPTNeoXForCausalLM
pretrained_model = GPTNeoXForCausalLM.from_pretrained(
  "EleutherAI/pythia-70m-deduped",
  revision="step3000",
  cache_dir="./pythia-70m-deduped/step3000",
)

# Wrap the model so it returns only the logits; otherwise it returns a CausalLMOutputWithPast object
class GPTWrapper(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, input_ids):
        return self.model(input_ids).logits

model = GPTWrapper(pretrained_model)

# Dummy input: one sequence of 8 token IDs
x = torch.randint(0, 100, (1, 8)).long()
output = model(x)

# Export the model with a dynamic batch axis
torch.onnx.export(
    model,
    x,
    "GPTNeoXForCausalLM.onnx",
    export_params=True,
    input_names=["input"],
    output_names=["output"],
    dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}},
)

# Save the flattened input and output as the JSON data file
data_array = x.detach().numpy().reshape([-1]).tolist()
data = dict(
    input_data=[data_array],
    output_data=[o.detach().numpy().reshape([-1]).tolist() for o in output],
)
with open("GPTNeoXForCausalLM.json", "w") as f:
    json.dump(data, f)

# Run gen settings to find the error
import ezkl
ezkl.gen_settings("GPTNeoXForCausalLM.onnx", "test_settings.json")

# Or try with the command line
# ezkl gen-settings -M GPTNeoXForCausalLM.onnx --settings-path=test_settings.json

# Both panic with:
# thread 'main' panicked at src/tensor/mod.rs:968:13:
# assertion failed: shape.contains(d) || *d == 1
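Because the panic is a shape assertion, a useful first check is to see which input and output shapes the exported ONNX graph declares. With the export call above, the batch axis should appear as the symbolic name batch_size rather than as a number. The helper below is a hypothetical sketch (describe_io is not part of ezkl or onnx). It reads only standard ModelProto fields: graph.input, graph.output, and each dimension's dim_param / dim_value. Loading the file requires the onnx package.

```python
def describe_io(model):
    """Hypothetical helper: map each graph input/output name to its dimensions.

    `model` is an onnx.ModelProto (e.g. from onnx.load). Symbolic axes such as
    "batch_size" are returned as strings, concrete axes as ints, and axes with
    neither set as None.
    """
    shapes = {}
    for value_info in list(model.graph.input) + list(model.graph.output):
        dims = []
        for dim in value_info.type.tensor_type.shape.dim:
            if dim.dim_param:          # named/symbolic axis, e.g. "batch_size"
                dims.append(dim.dim_param)
            elif dim.dim_value > 0:    # concrete axis length
                dims.append(dim.dim_value)
            else:                      # neither set: unknown length
                dims.append(None)
        shapes[value_info.name] = dims
    return shapes

# Usage (requires the onnx package and the file exported above):
#   import onnx
#   onnx_model = onnx.load("GPTNeoXForCausalLM.onnx")
#   onnx.checker.check_model(onnx_model)
#   print(describe_io(onnx_model))
```

Comparing these declared shapes between a GPTNeoX export and a GPTNeo export that works is one way to narrow down where the two graphs differ.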
tobinsouth commented 9 months ago

This matters because we want to run efficient models, and GPTNeoX models appear to use fewer FLOPs per parameter than GPTNeo models. Many of the best open-source LLMs also build on the NeoX architecture.

tobinsouth commented 9 months ago

Additional info: the same error occurs for other causal LM architectures, e.g. LlamaForCausalLM loaded in place of the GPTNeoX model:

from transformers import AutoModelForCausalLM, AutoTokenizer
pretrained_model = AutoModelForCausalLM.from_pretrained('JackFram/llama-160m')
tokenizer = AutoTokenizer.from_pretrained("JackFram/llama-160m")

P.S. If you download many models, you may want to clear the Hugging Face cache: rm -r ~/.cache/huggingface/hub/*
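The rm -r command removes every cached model at once. For more selective cleanup, huggingface_hub provides huggingface-cli scan-cache and huggingface-cli delete-cache. The stdlib-only sketch below does something similar by hand. It assumes the default cache location ~/.cache/huggingface/hub and the hub's models--{org}--{name} directory naming. The function names are hypothetical, and environment overrides such as HF_HOME are not handled.

```python
import os
import shutil
from pathlib import Path

# Assumed default hub cache location (env overrides such as HF_HOME are ignored here).
DEFAULT_HUB_CACHE = Path.home() / ".cache" / "huggingface" / "hub"


def _repo_dir_size(repo_dir):
    """Sum sizes of regular files, skipping symlinks (snapshots/ links into blobs/)."""
    total = 0
    for root, _dirs, files in os.walk(repo_dir):  # does not follow symlinked dirs
        for name in files:
            path = os.path.join(root, name)
            if not os.path.islink(path):
                total += os.path.getsize(path)
    return total


def list_cached_models(cache_dir=DEFAULT_HUB_CACHE):
    """Return [(repo_id, size_in_bytes), ...] for cached model repos."""
    cache_dir = Path(cache_dir)
    if not cache_dir.is_dir():
        return []
    models = []
    for entry in sorted(cache_dir.iterdir()):
        if entry.is_dir() and entry.name.startswith("models--"):
            repo_id = entry.name[len("models--"):].replace("--", "/")
            models.append((repo_id, _repo_dir_size(entry)))
    return models


def delete_cached_model(repo_id, cache_dir=DEFAULT_HUB_CACHE):
    """Delete one cached model repo, e.g. "JackFram/llama-160m". Returns True if removed."""
    target = Path(cache_dir) / ("models--" + repo_id.replace("/", "--"))
    if target.is_dir():
        shutil.rmtree(target)
        return True
    return False
```

For example, delete_cached_model("JackFram/llama-160m") would remove only the Llama checkpoint and leave other cached models in place.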

alexander-camuto commented 9 months ago

Tracked upstream in tract: https://github.com/sonos/tract/issues/1269