henry123-boy / SpaTracker

[CVPR 2024 Highlight] Official PyTorch implementation of SpatialTracker: Tracking Any 2D Pixels in 3D Space

Out of memory #21

Open · tungyen opened 5 months ago

tungyen commented 5 months ago

Hi, I can run your model's demo code on an NVIDIA RTX 3090, but it only succeeds on a 3-second video; a 9-second video runs out of memory. Is there any way to reduce memory usage without changing the GPU? Thank you very much.

m43 commented 3 months ago

Perhaps try running with half precision? For example, see the PyTorch docs on mixed precision, or wrap your code with PyTorch Lightning and use its built-in mixed-precision flags.
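A minimal sketch of what half-precision inference could look like with plain PyTorch `torch.autocast`. The toy model is a hypothetical stand-in, not SpaTracker's tracker or depth network, and I have not checked that every op in those models behaves well in 16-bit, so treat this as a starting point rather than a drop-in fix:

```python
import torch
import torch.nn as nn


def run_with_autocast(model: nn.Module, inputs: torch.Tensor) -> torch.Tensor:
    """Run inference under autocast so eligible ops run in a 16-bit dtype."""
    if torch.cuda.is_available():
        device, dtype = "cuda", torch.float16
    else:
        # CPU autocast uses bfloat16 rather than float16
        device, dtype = "cpu", torch.bfloat16
    model = model.to(device).eval()
    inputs = inputs.to(device)
    # no_grad avoids storing activations for backward; autocast lowers precision
    with torch.no_grad(), torch.autocast(device_type=device, dtype=dtype):
        return model(inputs)


if __name__ == "__main__":
    # Hypothetical stand-in model; swap in the real model you want to run
    toy_model = nn.Sequential(nn.Linear(64, 128), nn.ReLU(), nn.Linear(128, 8))
    out = run_with_autocast(toy_model, torch.randn(4, 64))
    print(out.dtype, tuple(out.shape))
```

Autocast keeps the weights in float32 and only casts eligible ops (matmuls, convolutions) down, which is usually safer than calling `.half()` on the whole model.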

If you are using the monocular depth estimator to obtain the depths, the memory bottleneck for longer videos may shift there. You can try batching the calls to the depth estimator, as in the following snippet, which predicts the depths 10 frames at a time:

import torch

with torch.no_grad():
    batch_size = 10  # frames per depth-estimator call; lower it if memory is still short
    frames = sample.video[0]  # frames of one clip, pixel values in [0, 255]
    vidDepths = []
    # Run the depth estimator on consecutive chunks of at most batch_size frames
    for start in range(0, frames.shape[0], batch_size):
        chunk = frames[start:start + batch_size]
        vidDepths.append(depth_predictor.infer(chunk / 255))
    # Stitch the per-chunk predictions back together along the frame axis
    videodepth = torch.cat(vidDepths, dim=0)

# Clamp the predicted depths to [depth_near, depth_far]
args.depth_near = 0.01
args.depth_far = 65.0
depths = videodepth.clamp(args.depth_near, args.depth_far)
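As a rough back-of-the-envelope illustration of why video length matters, the sketch below computes only the size of a dense float32 video tensor. The 30 fps frame rate and 480x640 resolution are hypothetical values, and real peak memory is much higher because of model activations, but the linear scaling with frame count is the point:

```python
def video_tensor_bytes(num_frames, height, width, channels=3, bytes_per_elem=4):
    """Size in bytes of a dense (T, C, H, W) video tensor (float32 by default)."""
    return num_frames * channels * height * width * bytes_per_elem


FPS = 30  # hypothetical frame rate
H, W = 480, 640  # hypothetical resolution

for seconds in (3, 9):
    frames = seconds * FPS
    gib = video_tensor_bytes(frames, H, W) / 2**30
    print(f"{seconds} s -> {frames} frames -> {gib:.2f} GiB (float32 input only)")

# The input to each chunked depth call scales with the chunk size, not the video length
chunk_gib = video_tensor_bytes(10, H, W) / 2**30
print(f"10-frame chunk -> {chunk_gib:.3f} GiB")
```

A 9-second clip is three times the data of a 3-second one, while each batched depth call only ever sees a 10-frame chunk. The collected depth maps still grow with the number of frames, but they have a single channel.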

Also, make sure gradients are not being computed, for example by wrapping inference in the torch.no_grad() context manager.
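A small sketch showing the difference `torch.no_grad()` makes, using a hypothetical toy layer rather than the tracker itself. Outside the context manager the output carries an autograd graph (and PyTorch keeps the intermediate tensors needed for backward); inside it, nothing is recorded:

```python
import torch

layer = torch.nn.Linear(16, 16)
x = torch.randn(8, 16)

# Autograd enabled: the output references a graph, which costs memory
y_train = layer(x)

# Inside torch.no_grad(), no graph is recorded for the forward pass
with torch.no_grad():
    y_infer = layer(x)

print(y_train.requires_grad, y_infer.requires_grad)  # True False
```

`torch.inference_mode()` is a stricter alternative for pure inference if you never need the outputs in autograd later.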