google-deepmind / tapnet

Tracking Any Point (TAP)
https://deepmind-tapir.github.io/blogpost.html
Apache License 2.0

GPU JAX does not speed up the inference of TAP-Net? #26

Closed: Wuziyi616 closed this issue 1 year ago.

Wuziyi616 commented 1 year ago

First of all, thank you so much for releasing this great work. I tested it on my own videos, and the tracking is really robust and accurate.

Following the README, I was able to run the CPU version of the code (the JAX in requirements.txt is the CPU-only build), and it is very fast. On a 300-frame video, tracking 24 points from the first frame takes only 10s (excluding the time spent painting and saving the output video).

Next, I tried the GPU version of JAX to speed up inference further. I successfully installed JAX with CUDA support (see the screenshot below), and nvidia-smi confirms that the code is indeed using the GPU (it consumes 20GB of memory on an RTX 3090). However, the run takes 15s, which is slower than the 10s on CPU. For reference, I'm using the command from the README:

python3 ./tapnet/experiment.py \
  --config=./tapnet/configs/tapnet_config.py \
  --jaxline_mode=eval_inference \
  --config.checkpoint_dir=./tapnet/checkpoint/ \
  --config.experiment_kwargs.config.inference.input_video_path=MY_VIDEO.mp4 \
  --config.experiment_kwargs.config.inference.output_video_path=result.mp4 \
  --config.experiment_kwargs.config.inference.resize_height=256 \
  --config.experiment_kwargs.config.inference.resize_width=256 \
  --config.experiment_kwargs.config.inference.num_points=24

I'm new to JAX, so I'd really appreciate any hints on why my GPU run is slower than the CPU one. Thanks!

EDIT: after reading some JAX-GPU related issues and the documentation, is it simply because the video and the number of points are too small? I.e., if I used a batch of videos or more points, would the GPU be faster?

[Screenshot: JAX with CUDA support installed]

cdoersch commented 1 year ago

I suspect the problem is that the first time you run any jit-compiled JAX function (or change the sizes/types of any arguments), it needs to recompile, and JAX compilation is very slow (and somewhat slower on GPU than CPU). Once it's compiled, the GPU version should be much faster.
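To make the recompilation point concrete, here is a toy model in plain Python of a compilation cache keyed on argument shape and dtype. It is hypothetical and is not JAX's implementation: `ToyJit` is an illustrative name, and `array.array` lengths and typecodes stand in for JAX array shapes and dtypes. Calls with the same shape reuse the cached entry, while a new shape pays the "compile" cost again.

```python
from array import array
from typing import Callable, Dict, Tuple


class ToyJit:
    """Toy model of a shape-specialized compilation cache (hypothetical).

    Mimics only the observable behaviour that a jitted function is compiled
    once per distinct argument shape/dtype signature, not per value.
    """

    def __init__(self, fn: Callable):
        self.fn = fn
        self.cache: Dict[Tuple, Callable] = {}
        self.compile_count = 0

    @staticmethod
    def signature(*args: array) -> Tuple:
        # Key on each argument's length and element type, never on its values.
        return tuple((len(a), a.typecode) for a in args)

    def __call__(self, *args: array):
        key = self.signature(*args)
        if key not in self.cache:
            # In JAX this is the slow trace-and-compile step.
            self.compile_count += 1
            # Stand-in for the compiled executable.
            self.cache[key] = self.fn
        return self.cache[key](*args)


track = ToyJit(sum)
track(array("f", [1.0, 2.0, 3.0]))  # new shape: "compiles"
track(array("f", [4.0, 5.0, 6.0]))  # same shape, new values: cache hit
track(array("f", [1.0, 2.0]))       # different shape: "compiles" again
print(track.compile_count)          # 2
```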

Maybe try running the same number of points on a video of the same size multiple times in a loop, and see if subsequent iterations are faster? If you think TAPIR's compilation is too slow, I'd suggest submitting a bug to JAX itself to get some experienced people to look at the compilation. I'd happily +1 that bug :-)
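A minimal sketch of the suggested experiment, assuming a zero-argument callable that runs one inference pass. `fake_inference` below is a hypothetical stand-in that sleeps longer on its first call to mimic compilation. Because JAX dispatches work asynchronously, the helper calls `block_until_ready()` on any result that has it (JAX arrays do), recursing into dicts, lists and tuples, so the timings include the device computation rather than just the dispatch.

```python
import statistics
import time
from typing import Any, Callable, List


def block(result: Any) -> Any:
    """Wait for asynchronous results; JAX arrays expose block_until_ready()."""
    if hasattr(result, "block_until_ready"):
        result.block_until_ready()
    elif isinstance(result, dict):
        for value in result.values():
            block(value)
    elif isinstance(result, (list, tuple)):
        for value in result:
            block(value)
    return result


def time_runs(fn: Callable[[], Any], n_runs: int = 10) -> List[float]:
    """Wall-clock seconds for each of n_runs consecutive calls to fn."""
    timings = []
    for _ in range(n_runs):
        start = time.perf_counter()
        block(fn())
        timings.append(time.perf_counter() - start)
    return timings


# Hypothetical stand-in for one inference pass: slow first call ("compile"),
# fast afterwards ("cached executable").
_calls = {"n": 0}


def fake_inference() -> dict:
    _calls["n"] += 1
    time.sleep(0.05 if _calls["n"] == 1 else 0.001)
    return {"tracks": [_calls["n"]]}


timings = time_runs(fake_inference, n_runs=5)
print(f"first run: {timings[0]:.3f}s, "
      f"median of later runs: {statistics.median(timings[1:]):.3f}s")
```

With real code you would pass something like `lambda: run_inference(video, queries)`, where `run_inference` is a hypothetical name for whatever jitted entry point you call; only `timings[1:]` reflects steady-state speed.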

Wuziyi616 commented 1 year ago

@cdoersch thank you so much for the detailed reply! Yes, JIT makes perfect sense. I ran the same inference code 10 times in a for-loop. The first run is still slow, but the remaining 9 runs are much faster, taking <0.2s each. Amazing!

Regarding the number of points (let's call it N): I understand that the code has to be recompiled every time the input shape changes. In my case, the number of points varies between videos, but the video shape is always the same. I looked at the code of TAP-Net:

I'm thinking of padding the number of points to a fixed value, so that the input shape never changes. Is that valid? I.e., is each point tracked independently of the others, so that I can pad the input with zero points and get the same results for the valid points as without padding?

cdoersch commented 1 year ago

Yes, padding with zeros makes sense in this situation. For TAPIR, the results you get for a single query point are independent of the other query points in the batch.
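A minimal sketch of the padding approach in plain Python. It assumes each query point is a `(t, y, x)` triple and that per-point outputs are indexed by query along their first axis; both are assumptions, so check the layout your TAP-Net/TAPIR entry point actually uses. `pad_queries`, `unpad`, `PAD_QUERY` and `max_points` are illustrative names. Because each query's result is independent of the others, the outputs for the padded entries can simply be discarded.

```python
from typing import List, Sequence, Tuple

# Assumed query layout: (t, y, x) = (frame index, row, column).
Query = Tuple[float, float, float]
PAD_QUERY: Query = (0.0, 0.0, 0.0)


def pad_queries(queries: Sequence[Query], max_points: int) -> Tuple[List[Query], int]:
    """Pad queries to exactly max_points entries; also return the real count."""
    num_valid = len(queries)
    if num_valid > max_points:
        raise ValueError(f"{num_valid} queries exceed max_points={max_points}")
    padded = list(queries) + [PAD_QUERY] * (max_points - num_valid)
    return padded, num_valid


def unpad(per_point_results: Sequence, num_valid: int) -> list:
    """Keep only the outputs for real queries (assumes point axis comes first)."""
    # Each query is tracked independently, so padding does not change these.
    return list(per_point_results[:num_valid])


queries = [(0.0, 10.0, 20.0), (0.0, 30.0, 40.0)]
padded, num_valid = pad_queries(queries, max_points=24)
print(len(padded), num_valid)  # 24 2
```

Every video then presents the same query shape, so after the first call the compiled function is reused; the cost is some wasted work on the padded queries.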

Note that if the number of points is large, you probably want to take care with the query_chunk_size parameter. Smaller values will use less memory, but at the cost of a larger computational graph and therefore longer compile times. Typically you want to use the largest value that you can without running out of memory.
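The trade-off behind `query_chunk_size` can be sketched as a loop over fixed-size slices of the queries. This is an illustration under assumptions, not tapnet's code: `run_in_chunks`, `num_chunks` and `process_chunk` are hypothetical names. Per-chunk intermediates scale with the chunk size, while the number of chunk steps is ceil(N / chunk_size). If a Python loop like this runs inside a jit-traced function it is unrolled at trace time, so each step adds to the compiled graph, which is one way smaller chunks lead to longer compiles.

```python
import math
from typing import Callable, List, Sequence, TypeVar

T = TypeVar("T")
R = TypeVar("R")


def num_chunks(num_queries: int, chunk_size: int) -> int:
    """Number of chunk steps needed to cover num_queries queries."""
    return math.ceil(num_queries / chunk_size)


def run_in_chunks(queries: Sequence[T], chunk_size: int,
                  process_chunk: Callable[[Sequence[T]], List[R]]) -> List[R]:
    """Apply process_chunk to consecutive slices of at most chunk_size queries."""
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")
    results: List[R] = []
    for start in range(0, len(queries), chunk_size):
        # Per-chunk work scales with chunk_size, not with the total query count.
        results.extend(process_chunk(queries[start:start + chunk_size]))
    return results


print(num_chunks(1000, 64))  # 16
print(run_in_chunks(list(range(10)), 4, lambda chunk: [q * 2 for q in chunk]))
```

Note that the last slice can be shorter than `chunk_size`; if `process_chunk` were itself jitted, that ragged chunk would have a different shape and trigger an extra compile unless it is padded.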

Wuziyi616 commented 1 year ago

Thanks a lot! Just wanna confirm, is the per-point result of TAP-Net also independent of the other query points? (according to the paper this is true)

Feel free to close the issue when you reply. I really appreciate your help.

cdoersch commented 1 year ago

Yes, it's also true for TAP-Net, although I'm curious why you would want to use TAP-Net now that TAPIR is out.

Wuziyi616 commented 1 year ago

Thanks! That's simply because, according to Table 9 of the TAPIR paper, TAP-Net runs much faster than TAPIR, and my application needs to run at 50-100 Hz. Also, my videos come from simple RL environments without large motions, severe occlusions, etc. That said, I'll definitely try TAPIR later to see whether it's fast enough : )
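For a rough sense of the target, the 50-100 Hz requirement translates into a per-frame time budget, and the warm runs reported earlier (under 0.2s per run, presumably on the same 300-frame video) can be expressed as an amortized per-frame cost. Caveat: that figure is amortized over an offline pass over the whole video, which is not the same as the latency of processing each new frame as it arrives.

```python
def frame_budget_ms(rate_hz: float) -> float:
    """Milliseconds available per frame at a given processing rate."""
    return 1000.0 / rate_hz


def amortized_ms_per_frame(total_seconds: float, num_frames: int) -> float:
    """Average milliseconds per frame when a whole clip is processed at once."""
    return 1000.0 * total_seconds / num_frames


for hz in (50, 100):
    print(f"{hz} Hz -> {frame_budget_ms(hz):.1f} ms per frame")
# 50 Hz -> 20.0 ms per frame
# 100 Hz -> 10.0 ms per frame

# Upper bound from the warm runs: under 0.2 s for 300 frames.
print(f"{amortized_ms_per_frame(0.2, 300):.2f} ms per frame")  # 0.67 ms per frame
```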