kingoflolz / mesh-transformer-jax

Model parallel transformers in JAX and Haiku
Apache License 2.0

GPT-J inference on TPU #219

Open airesearch38 opened 2 years ago

airesearch38 commented 2 years ago

Is it possible to use a TPU for inference?

The guys at NLPCloud.io told me that's what they're doing, but I have no idea how they're doing it... First, I don't know how to support advanced parameters like end_sequence (so the model stops generating when it reaches a specific token) or repetition penalty (see the Hugging Face text-generation parameters). Second, the TPU IPs seem to rotate on a regular basis and there's nothing you can do about it. So I'm not sure how to serve inference from a TPU through a REST API...
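To be concrete about the repetition penalty part: the Hugging Face rule (from the CTRL paper) rescales the logits of tokens that have already been generated, before sampling. A rough NumPy sketch of that idea, with an illustrative function name (not any library's API):

    import numpy as np

    def apply_repetition_penalty(logits, generated_ids, penalty=1.2):
        # CTRL-style rule: already-seen tokens become less likely either way,
        # by dividing positive logits and multiplying negative ones.
        for token_id in set(generated_ids):
            score = logits[token_id]
            logits[token_id] = score / penalty if score > 0 else score * penalty
        return logits

    # usage: logits = apply_repetition_penalty(logits, generated_ids)

What I don't know is how to wire this kind of logit processing into the TPU sampling loop.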

Thanks in advance!

leejason commented 2 years ago

You might consider running "device_serve.py" on the TPU, combined with the Streamlit approach used in the following repo:

https://github.com/vicgalle/gpt-j-api
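The Streamlit side is only a few lines, roughly like the sketch below. Note the endpoint URL and JSON field names are placeholders of mine; check device_serve.py and the gpt-j-api repo for the actual route and payload:

    import requests
    import streamlit as st

    # Placeholder endpoint; the real route/port depends on how the server is run.
    API_URL = "http://<tpu-ip>:5000/complete"

    st.title("GPT-J demo")
    prompt = st.text_area("Prompt", "Once upon a time")

    if st.button("Generate"):
        # Field names ("context", "length") are assumptions for illustration.
        resp = requests.post(API_URL, json={"context": prompt, "length": 128})
        resp.raise_for_status()
        st.write(resp.json())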

airesearch38 commented 2 years ago

Interesting, thanks for the suggestion!

If I understand the code correctly, stop_sequence does not stop generation early; it simply trims the result once the model has finished generating:

    if stop_sequence is not None and stop_sequence in text:
        text = text.split(stop_sequence)[0] + stop_sequence

So generation takes the same time whether the stop_sequence token is reached or not. Am I correct?
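If generation could stop early, I imagine it would have to check the decoded output after every sampled token, roughly like this sketch (sample_next_token and tokenizer are placeholders, not this repo's API):

    def generate_with_stop(prompt_ids, sample_next_token, tokenizer,
                           stop_sequence, max_tokens=256):
        # Sketch: stop as soon as the stop sequence appears, instead of
        # trimming after a fixed-length generation has finished.
        generated = list(prompt_ids)
        for _ in range(max_tokens):
            generated.append(sample_next_token(generated))
            # Only decode the newly generated tail so the check stays cheap.
            tail = tokenizer.decode(generated[len(prompt_ids):])
            if stop_sequence in tail:
                return tail.split(stop_sequence)[0] + stop_sequence
        return tokenizer.decode(generated[len(prompt_ids):])

I suspect the fixed-length generation compiled for the TPU is exactly why the API trims afterwards instead of doing a per-step host-side check like this.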

And I don't see a way to handle the fact that TPU IPs are regularly changing...
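Maybe the only option is to treat the address as ephemeral and look it up again right before each connection, e.g. with the gcloud CLI. A sketch; it assumes gcloud is installed and authenticated, and that the node reports its address under networkEndpoints:

    import subprocess

    def current_tpu_ip(tpu_name, zone):
        # Re-resolve the node's current IP each time instead of caching it.
        return subprocess.check_output(
            ["gcloud", "compute", "tpus", "describe", tpu_name,
             "--zone", zone,
             "--format", "value(networkEndpoints[0].ipAddress)"],
            text=True,
        ).strip()

A stable reverse proxy in front of the TPU would be another way to give clients a fixed address.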

leejason commented 2 years ago

I was trying Streamlit as a quick web app for testing model inference and found it convenient. Indeed, the floating IP of the TPU is a separate issue. As for stop_sequence, I can't comment because I haven't run into any problems with it yet. In brief, "device_serve.py" works on TPU; it could be a starting point.