kingoflolz / mesh-transformer-jax

Model parallel transformers in JAX and Haiku

Inference speed on TPU #115

Closed gamcoh closed 3 years ago

gamcoh commented 3 years ago

Hi everyone,

I fine-tuned the model on my custom data and now I want to serve it. Here's what I did:

```python
import time
from queue import Empty

import transformers
from transformers import AutoConfig, GPTNeoForCausalLM
# Checkpoint and requests_queue come from the rest of the serving script (not shown here)

# GPT-J 6B config
config = AutoConfig.from_pretrained("EleutherAI/gpt-neo-2.7B")
config.attention_layers = ["global"] * 28
config.attention_types = [["global"], 28]
config.num_layers = 28
config.num_heads = 16
config.hidden_size = 256 * config.num_heads
config.vocab_size = 50400
config.rotary = True
config.rotary_dim = 64
config.jax = True

# Load the model
start = time.time()
model = GPTNeoForCausalLM.from_pretrained(pretrained_model_name_or_path=None,
                                          config=config,
                                          state_dict=Checkpoint("./email-copilot-hf"))
print(f'Loaded model {time.time() - start}')
tokenizer = transformers.GPT2TokenizerFast.from_pretrained("gpt2")

while True:
    # Collect up to 8 pending requests from the queue
    all_options = []
    all_q = []
    while len(all_options) < 8:
        try:
            o, q = requests_queue.get(block=False)
            all_options.append(o)
            all_q.append(q)
        except Empty:
            if len(all_options):
                break
            else:
                time.sleep(0.01)

    start = time.time()

    for i in range(len(all_options)):
        ctx = all_options[i]
        q = all_q[i]
        try:
            print(ctx)
            tokens = tokenizer.encode(ctx['context'], return_tensors="pt")
            output = model.generate(tokens,
                                    top_p=ctx['top_p'],
                                    top_k=ctx['top_k'],
                                    temperature=ctx['temp'],
                                    max_length=ctx['length'],
                                    do_sample=ctx['sample'],
                                    use_cache=ctx['cache'])
            q.put(tokenizer.decode(output[0], skip_special_tokens=False))
        except Exception as e:
            print(e)

    print(f"completion done in {time.time() - start:06}s")
```

And the serving loop above works great, but it's slow.
Here are some benchmark numbers:

```
top_k=100 top_p=0.9 temp=0.9 max_length=20 cache=True do_sample=True
completion done in 6.864500999450684s
```



This is a development server running on the Google TPU recommended in the `howto_finetune.md` file.

Is there a way to get the API or the model to run faster without some post-processing?

Thanks!
kingoflolz commented 3 years ago

Nothing looks particularly wrong; running inference on large models is just inherently quite slow. Feel free to do some profiling and let me know if anything looks out of place.
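For that breakdown, a minimal timing sketch (assuming the `model` and `tokenizer` from the snippet above, with a made-up prompt standing in for `ctx['context']`) is to time tokenization, a single forward pass, and the full `generate()` call separately:

```python
import time
import torch

prompt = "Hello, my name is"  # hypothetical stand-in for ctx['context']

t0 = time.time()
tokens = tokenizer.encode(prompt, return_tensors="pt")
t1 = time.time()

with torch.no_grad():
    model(tokens)  # one forward pass over the prompt
t2 = time.time()

model.generate(tokens, do_sample=True, top_p=0.9, top_k=100,
               temperature=0.9, max_length=20, use_cache=True)
t3 = time.time()

print(f"tokenize: {t1 - t0:.3f}s  forward: {t2 - t1:.3f}s  generate: {t3 - t2:.3f}s")
```

If most of the time sits in `generate()` and scales with `max_length`, the cost is the token-by-token decoding loop itself rather than anything in the serving code.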