google-research / vision_transformer


Inference on LiT Model is slow? #263

Closed famishedrover closed 1 year ago

famishedrover commented 1 year ago

I'm using jaxlib with cuDNN to run the LiT model. The inference time for a single image is around 5 seconds on a Quadro RTX 8000 (48 GB). I'm also new to JAX, so maybe I'm missing something there.

On CPU (I just removed the CUDA/cuDNN jaxlib build to test this), the times are much faster, around 2 seconds. I found that the bulk of this time comes from the ViT model. PyTorch ViT models are extremely fast by comparison (a few milliseconds, maybe).

Questions :

  1. Is there a PyTorch version of the trained LiT model?
  2. If not, can someone help me run the LiT model faster on JAX? I suppose it must be possible, since training LiT would have been impractical at the speeds I'm getting.

Configuration :

jax==0.4.7
jaxlib==0.4.7+cuda12.cudnn88
vit-jax @ git+https://github.com/google-research/vision_transformer@85c4f53febd929c43e70e8ff598f9f00d52948b7

CUDA 11.4

I have also tried CUDA 12.1, with similar results.

Thanks!

andsteing commented 1 year ago

What do you see when you run jax.devices()?
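
For example, a minimal check (the exact device repr differs between jaxlib versions):

import jax
print(jax.devices())          # should list a GPU/CUDA device, not only a CPU device
print(jax.default_backend())  # 'gpu' (or 'cuda') with a working GPU install, 'cpu' otherwise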

Can you run a simple speed test?

import jax

# 8k x 8k matmul as a quick GPU-vs-CPU check (run in IPython/Colab for the %timeit magics).
x = jax.random.normal(jax.random.PRNGKey(0), [8_000, 8_000])
%timeit x@x
# Same matmul with JAX instructed to place the computation on the CPU instead.
with jax.default_device(jax.devices('cpu')[0]):
  %timeit x@x

On a T4, I get ~52.8 microseconds for the first %timeit, and around 16 seconds for the second, where I instruct JAX not to use the GPU.
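
(Since JAX dispatches work asynchronously, a bare %timeit x@x on GPU can partly measure dispatch rather than the matmul itself; as a sanity check you can block on the result:)

%timeit (x @ x).block_until_ready()  # waits for the GPU matmul to actually finish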

If you're seeing slow evaluations with the LiT model, it's most probably because JAX is not actually using the GPU, rather than anything specific to the LiT model implementation.
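
One way to make sure the arrays (and hence the computation) actually end up on the GPU is to commit them explicitly. A minimal sketch, not specific to LiT; the input shape here is just a placeholder:

import jax
import jax.numpy as jnp

gpu = jax.devices('gpu')[0]                               # raises if JAX has no GPU backend
batch = jax.device_put(jnp.ones([1, 224, 224, 3]), gpu)   # commit the input batch to the GPU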

famishedrover commented 1 year ago

Here are my results.

94.9 ms ± 184 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
434 ms ± 2.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

I was able to improve the inference speed using these two methods:

  1. Force JAX to use the GPU. From what I could find, JAX should use the GPU by default when one is available, but this context manager was helpful:

    with jax.default_device(jax.devices('gpu')[0]):
  2. Jit the function properly. I found that jax.jit compiles a separate executable for each input shape, including the batch size of the dummy input. Once I compiled with a fixed batch size and used that same batch size for all further inference, it was much faster (a sketch follows below).
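
A minimal sketch of point 2, assuming a generic forward pass (apply_fn, params and BATCH are placeholders, not the actual vit_jax API): jax.jit caches one compiled executable per input shape, so padding every request to a single fixed batch size lets that executable be reused instead of triggering recompilation.

import jax
import jax.numpy as jnp

BATCH = 50  # fixed batch size to compile for

def apply_fn(params, images):              # stand-in for the real LiT/ViT forward pass
  return images.reshape(images.shape[0], -1) @ params

encode = jax.jit(apply_fn)                 # compiles once per distinct input shape/dtype

def encode_padded(params, images):
  # Pad every request up to BATCH so the already-compiled executable is reused.
  n = images.shape[0]
  padded = jnp.zeros((BATCH,) + images.shape[1:], images.dtype).at[:n].set(images)
  return encode(params, padded)[:n]        # drop the padding rows afterwards

The first call at the fixed shape pays the compilation cost; later calls with any n ≤ BATCH reuse it.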

Now I'm getting:

84.8 ms ± 1.12 ms for batch_size = 1
332 ms for batch_size = 50

Thanks!