152334H / tortoise-tts-fast

Fast TorToiSe inference (5x or your money back!)
GNU Affero General Public License v3.0

GPT optimisation with TensorRT/FasterTransformer/Triton/??? #3

Open 152334H opened 1 year ago

152334H commented 1 year ago

This issue is likely to take a substantial amount of effort to solve.

Primary problem: GPT2InferenceModel is a custom subclass of 🤗's transformers.GPT2Model. Its architecture differs significantly from a vanilla GPT-2, so a substantial amount of new code would be needed to make an optimized version of the model work with the usual Hugging Face generation function.
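To illustrate why this is non-trivial, here is a minimal sketch of the general pattern such a custom inference model has to follow to stay compatible with Hugging Face's `generate()` (this is illustrative only, not the actual tortoise code; `CustomInferenceModel`, `custom_embeddings`, and `lm_head` are assumed names):

```python
import torch
from transformers import GPT2Config, GPT2Model, GPT2PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions

class CustomInferenceModel(GPT2PreTrainedModel):
    """Sketch: a GPT-2 stack with non-standard input embeddings and output head
    that still works with HuggingFace's generate() loop."""
    def __init__(self, config: GPT2Config, custom_embeddings: torch.nn.Module,
                 lm_head: torch.nn.Module):
        super().__init__(config)
        self.transformer = GPT2Model(config)  # the stock GPT-2 transformer stack
        self.embeddings = custom_embeddings   # non-standard input embeddings
        self.lm_head = lm_head                # non-standard output head

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
        # generate() calls this every step; any TRT/FasterTransformer port has to
        # replicate this KV-cache plumbing as well, not just the matmuls.
        if past_key_values is not None:
            input_ids = input_ids[:, -1:]
        return {"input_ids": input_ids, "past_key_values": past_key_values}

    def forward(self, input_ids=None, past_key_values=None, **kwargs):
        hidden = self.embeddings(input_ids)
        out = self.transformer(inputs_embeds=hidden, past_key_values=past_key_values,
                               use_cache=True, return_dict=True)
        return CausalLMOutputWithCrossAttentions(
            logits=self.lm_head(out.last_hidden_state),
            past_key_values=out.past_key_values,
        )
```

Every piece of this glue (custom embeddings, custom head, KV-cache handling) is something an optimized backend would need to reproduce.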

pepinu commented 1 year ago

I've tried using https://github.com/NVIDIA-AI-IOT/torch2trt, but to no avail; too many ops aren't implemented there:

Warning: Encountered known unsupported method torch.arange
Warning: Encountered known unsupported method torch.zeros
Warning: Encountered known unsupported method torch.addmm
[02/06/2023-19:18:56] [TRT] [E] 4: model.h.0.attn:1:SLICE:GPU: mismatch in number of dimensions for start.
[02/06/2023-19:18:56] [TRT] [E] 4: model.h.0.attn:1:SLICE:GPU: mismatch in number of dimensions for start.
[02/06/2023-19:18:56] [TRT] [E] 4: model.h.0.attn:1:SLICE:GPU: mismatch in number of dimensions for start.
[02/06/2023-19:18:56] [TRT] [E] 4: model.h.0.attn:1:SLICE:GPU: mismatch in number of dimensions for start.

I have, however, had some success with JIT TorchScript, but there is a big discrepancy in the results:

/opt/conda/lib/python3.8/site-packages/torch/jit/_trace.py:1001: TracerWarning: Output nr 1. of the traced function does not match the corresponding output of the Python function. Detailed error:
Tensor-likes are not close!

Mismatched elements: 524072 / 547840 (95.7%)
Greatest absolute difference: 0.011121749877929688 at index (0, 502, 440) (up to 1e-05 allowed)
Greatest relative difference: 416.0444345892299 at index (0, 78, 657) (up to 1e-05 allowed)
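For reference, a warning like the one above comes from the trace-checking step of a call along these lines (a minimal sketch, assuming `gpt` is the eval-mode autoregressive module being converted and the example input is a dummy token tensor; this is not the exact conversion script used here):

```python
import torch

# Sketch: trace the module and let torch.jit re-run the trace to compare
# traced vs. eager outputs, which is what emits "Tensor-likes are not close!".
gpt = gpt.eval().cuda()
example_tokens = torch.randint(0, 255, (1, 512), device="cuda")

traced = torch.jit.trace(
    gpt,
    example_tokens,
    check_trace=True,      # re-run the trace against the Python function
    check_tolerance=1e-5,  # the tolerance referenced in the warning above
)
traced.save("gpt_traced.pt")
```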

We won't be able to convert the whole UnifiedVoice class to TensorRT, however, because the cross-entropy loss calculation depends on long (int64) tensors, which is non-negotiable within PyTorch, while TensorRT does not support that dtype.

The ONNX conversion worked like a charm, and I've looked through the Hugging Face link you referred to somewhere. I'll hack something up to maybe get the GPT-2 part of the model onto TRT and see how it can be connected.
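The ONNX export in question would look roughly like this (a minimal sketch: `gpt`, the dummy input shape, the output file name, and the opset version are assumptions, not the exact command that was run):

```python
import torch

# Sketch: export only the GPT-2 part to ONNX with dynamic batch/sequence axes,
# so the resulting graph can later be handed to TensorRT or ONNX Runtime.
dummy_tokens = torch.randint(0, 255, (1, 512), dtype=torch.long)

torch.onnx.export(
    gpt,                       # assumed: the GPT-2 submodule of UnifiedVoice
    (dummy_tokens,),
    "gpt2_part.onnx",
    input_names=["input_ids"],
    output_names=["logits"],
    dynamic_axes={"input_ids": {0: "batch", 1: "seq"},
                  "logits": {0: "batch", 1: "seq"}},
    opset_version=17,
)
```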

Ryu1845 commented 1 year ago

@pepinu have you made any progress on this?

Ryu1845 commented 1 year ago

This might be interesting: https://onnxruntime.ai/docs/performance/tune-performance.html
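The kind of tuning that page describes boils down to choosing an execution provider and enabling graph optimizations when the session is created. A minimal sketch (the model path is an assumption carried over from the export above):

```python
import onnxruntime as ort

# Sketch: enable full graph optimizations and prefer the CUDA provider,
# falling back to CPU if it is unavailable.
opts = ort.SessionOptions()
opts.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL

sess = ort.InferenceSession(
    "gpt2_part.onnx",
    sess_options=opts,
    providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
)
```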

pepinu commented 1 year ago

https://github.com/152334H/tortoise-tts-fast/issues/3#issuecomment-1425568785 sorry, I got sick last week; I'll get back to working on this tomorrow 😞

slightly off topic:

I also have an idea to use a pretrained distilled GPT-2 instead of the one we have now; however, I'm not sure how to fine-tune it. I know that not all of the training code for tortoise was open-sourced, but maybe the part for the autoregressive model is?

link to distilled gpt2: https://huggingface.co/distilgpt2

here's the issue: https://github.com/neonbjb/tortoise-tts/issues/318

I ran a test, and the distilled checkpoint runs twice as fast on some random sequence (for test generation):

[Screenshot: timing comparison, 2023-02-14 21:45]
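A test like the one described above can be reproduced with a generic side-by-side timing of the stock checkpoints (a minimal sketch; it says nothing about output quality after fine-tuning on tortoise's data, and the prompt and token budget are arbitrary assumptions):

```python
import time
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Sketch: time greedy-ish generation from gpt2 vs. distilgpt2 on a throwaway prompt.
tok = AutoTokenizer.from_pretrained("gpt2")
ids = tok("some random sequence for test generation", return_tensors="pt").input_ids

for name in ["gpt2", "distilgpt2"]:
    model = AutoModelForCausalLM.from_pretrained(name).eval()
    start = time.time()
    with torch.no_grad():
        model.generate(ids, max_new_tokens=256, do_sample=True)
    print(name, f"{time.time() - start:.2f}s")
```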
DrewScatterday commented 1 year ago

I'm a little late to this thread, @pepinu, but did you ever make any more progress on this?

I was able to use DeepSpeed on Linux WSL to get some slight speedups with inference on fast tortoise. My next goal is figuring out whether it's possible to split tortoise inference across multiple GPUs (with something like Triton) for even faster speedups.

152334H commented 1 year ago

That seems unlikely to work... are you sure that FLOPs are the bottleneck rather than memory bandwidth, and that splitting wouldn't reduce speed via communication overhead?

DrewScatterday commented 1 year ago

Yeah, that's a good point. I was talking with someone else about it, and they said tortoise might not be large enough for multiple GPUs to really help.

DrewScatterday commented 1 year ago

Do you have any other recommendations for things to try, @152334H?

With really low-quality inference parameters and DeepSpeed, I'm able to generate 23 seconds of audio (345 characters) in roughly 14 seconds on a 3070 Ti with 8 GB of VRAM, which I'm pretty happy with. I was thinking that on something like a 3090 or V100 I could get this even lower, hopefully around 10 seconds for 23 seconds of audio.
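For context, "really low quality inference parameters" roughly corresponds to the fastest preset exposed by tortoise's Python API. A minimal sketch, based on upstream tortoise's documented usage (tortoise-tts-fast layers extra knobs and optimizations on top of this; the voice name and output path are assumptions):

```python
import torchaudio
from tortoise.api import TextToSpeech
from tortoise.utils.audio import load_voice

# Sketch: generate with the fastest preset, which uses the fewest
# autoregressive samples and diffusion iterations.
tts = TextToSpeech()
voice_samples, conditioning_latents = load_voice("tom")  # assumed bundled voice

gen = tts.tts_with_preset(
    "Some 345-character test passage goes here.",
    voice_samples=voice_samples,
    conditioning_latents=conditioning_latents,
    preset="ultra_fast",
)
torchaudio.save("out.wav", gen.squeeze(0).cpu(), 24000)
```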

DrewScatterday commented 1 year ago

DeepSpeed's init_inference function (which I've added to autoregressive.py) seems to allow tensor parallelism, but I'll have to do some tinkering to see whether that actually helps tortoise or not.
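A minimal sketch of what that wrapping looks like (the `gpt` module is an assumption, and newer DeepSpeed releases express the parallel degree via a tensor_parallel config rather than `mp_size`; the script would also need to be launched with the deepspeed launcher for more than one GPU):

```python
import deepspeed
import torch

# Sketch: wrap the autoregressive GPT with DeepSpeed inference and request
# tensor parallelism across two GPUs.
ds_engine = deepspeed.init_inference(
    gpt,                             # assumed: the autoregressive GPT module
    mp_size=2,                       # tensor-parallel degree
    dtype=torch.float16,
    replace_with_kernel_inject=True, # swap in fused inference kernels where supported
)
gpt = ds_engine.module
```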

152334H commented 1 year ago

I can't help much. Anything that increases the generation speed of a GPT-2-M-sized model in general will speed up this situation, but doing that is honestly pretty hard.

(With multi-GPU, couldn't you just do batch parallelism, though?)
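Batch parallelism here just means running an independent model replica per GPU, each handling its own slice of the requests, with no cross-GPU communication. A minimal sketch of the idea (`build_model` and `prompts` are assumed placeholders):

```python
from concurrent.futures import ThreadPoolExecutor
import torch

# Sketch: one full model copy per GPU, each generating for its own chunk of prompts.
def worker(device, prompt_chunk):
    model = build_model().to(device).eval()  # assumed constructor for the model
    with torch.no_grad():
        return [model.generate(p.to(device)) for p in prompt_chunk]

devices = [f"cuda:{i}" for i in range(torch.cuda.device_count())]
chunks = [prompts[i::len(devices)] for i in range(len(devices))]

with ThreadPoolExecutor(max_workers=len(devices)) as pool:
    results = list(pool.map(worker, devices, chunks))
```

This improves throughput for batches of requests, but does nothing for the latency of a single generation, which is the point of the FLOPs-vs-bandwidth question above.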

DrewScatterday commented 1 year ago

I'll report back and let you know what I figure out with multiple GPUs.