btsmart / splatt3r

Official repository for Splatt3R: Zero-shot Gaussian Splatting from Uncalibrated Image Pairs

The GPU memory usage continues to increase as the number of epochs increases #18

Open mxuai opened 2 months ago

mxuai commented 2 months ago

[Screenshot: Weixin Screenshot_20240905135527]

Thank you for open-sourcing such impressive work. However, when I tried running main.py with 4 RTX 3090 GPUs under the default settings, I encountered an NCCL error in the multi-GPU part: [Rank 2] Watchdog caught collective operation timeout: WorkNCCL. I haven't fully identified the source of the error and suspect it may be an issue with my workstation, so I switched to running your training program on a single GPU. As shown in the screenshot, GPU memory usage increases with each epoch, eventually leading to an out-of-memory error. Do you have any idea where a memory leak might occur in your code?

btsmart commented 2 months ago

Hello,

I suspect this issue relates to the fact that the Gaussian Splatting renderer we use can require different amounts of memory depending on the Gaussians being rendered. In particular, we found that if we do not correctly create loss masks for our samples during training, the Gaussians tend to grow larger as training progresses, which causes the memory requirements to increase over time. I believe this is because more Gaussians end up in each rasterization tile used by the renderer. If you visualize your samples/loss masks and everything seems normal, please let me know.
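
If it helps with checking them, here is a minimal sketch of how the masks could be dumped to disk for inspection. `save_loss_masks` is just an illustrative helper, and it assumes the mask tensor has the (batch, target views, H, W) layout produced by the in-frustum mask code:

```python
import os
import torchvision

def save_loss_masks(mask, out_dir="mask_debug"):
    """Write each loss mask as a black/white PNG for quick visual inspection.

    Assumes `mask` is a (batch, target_views, H, W) boolean/float tensor,
    e.g. the output of utils.loss_mask.calculate_in_frustum_mask."""
    os.makedirs(out_dir, exist_ok=True)
    for b in range(mask.shape[0]):
        for v in range(mask.shape[1]):
            img = mask[b, v].float().unsqueeze(0).cpu()  # 1 x H x W, values in [0, 1]
            torchvision.utils.save_image(img, os.path.join(out_dir, f"mask_b{b}_v{v}.png"))
```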

Regarding the multi-GPU NCCL error, I used to see a few of those related to DDP training with PyTorch Lightning and deadlocks between threads. I believe I fixed all of those issues, but that might be a good place to start looking.
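
If you want to keep debugging the multi-GPU run in the meantime, one thing that sometimes helps is giving the NCCL watchdog a longer timeout so a slow rank isn't killed mid-debugging. A rough sketch (the exact DDPStrategy arguments depend on your PyTorch Lightning version, so treat this as illustrative rather than something taken from this repository):

```python
from datetime import timedelta

import pytorch_lightning as pl
from pytorch_lightning.strategies import DDPStrategy

# Raise the collective timeout and disable unused-parameter detection while
# debugging NCCL watchdog timeouts. Argument availability depends on the
# installed Lightning version.
strategy = DDPStrategy(
    find_unused_parameters=False,
    timeout=timedelta(minutes=60),
)
trainer = pl.Trainer(accelerator="gpu", devices=4, strategy=strategy)
```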

mxuai commented 2 months ago

Hi, thank you for your quick reply. Although I haven't found the exact cause of the increasing memory usage, I did find that adding `torch.cuda.empty_cache()` in val_step() seems to solve the issue. As for the DDP training problem, thank you very much for the suggestion; I will try to figure it out.
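
For reference, the change was just something like the following, with the existing validation logic abbreviated to a placeholder:

```python
import torch

def val_step(self, batch, batch_idx):
    out = ...  # existing validation logic, unchanged
    torch.cuda.empty_cache()  # release cached blocks after each validation step
    return out
```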

btsmart commented 2 months ago

Interesting, thanks for letting me know. I am not sure why val_step() would be causing a memory leak, but I'll have a look and let you know if I find anything.

mxuai commented 2 months ago

Hi, to avoid reporting potentially misleading information, I ran the training several more times and found that even when emptying the GPU cache during the validation phase, there are still occasional instances (as shown in Fig. 1) where the GPU memory usage increases. Following your suggestion to check the loss masks, I found that most of them are normal, but a small portion are indeed incorrect. It seems this issue is difficult to avoid completely.

[Fig. 1: Weixin Screenshot_20240910162107]

Additionally, there's a minor issue I think is worth sharing. With the default learning rate schedule I occasionally saw unstable training, where the loss would rapidly increase and the model would collapse (as shown in Fig. 2), with all rendered images turning completely black. After adjusting the learning rate scheduler to start decaying at 1/5 of the total epochs, the issue was resolved.

[Fig. 2: Weixin Screenshot_20240910162142]
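
For completeness, the adjusted schedule is roughly equivalent to the sketch below: hold the base learning rate for the first fifth of training, then decay it. The cosine shape and the function name are just illustrative, not the scheduler actually used in this repository.

```python
import math
import torch

def make_delayed_decay_schedule(optimizer, total_steps, hold_fraction=0.2):
    """Keep the base LR for the first `hold_fraction` of training (1/5 here),
    then decay it smoothly with a cosine curve over the remaining steps."""
    hold_steps = int(total_steps * hold_fraction)

    def lr_lambda(step):
        if step < hold_steps:
            return 1.0
        progress = (step - hold_steps) / max(1, total_steps - hold_steps)
        return 0.5 * (1.0 + math.cos(math.pi * min(progress, 1.0)))

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
```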

mxuai commented 2 months ago

By the way, I'm very interested in the method you use to generate the coverage files, but it seems the preprocessing part isn't included in the released code. Could I ask here how you calculate the overlap between each pair of images? If you think this is better suited to a separate issue, I'll delete this comment and open a new one instead. Thanks.

btsmart commented 2 months ago

Thanks for sharing your results with the modified learning rate schedule. If I find any more information about the memory usage or the incorrect masks, I will let you know.

Regarding the coverage files: we use the same procedure as for the loss masks to generate a mask showing which pixels in the second image are visible in the first (ones for pixels with correspondences, zeroes for those without), and the 'overlap' between the pair of images is then the mean of that mask.
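
As a tiny made-up illustration of that definition:

```python
import torch

# Toy 4x4 visibility mask: 1 where a pixel of the second image has a valid
# correspondence in the first image, 0 otherwise.
mask = torch.tensor([
    [1, 1, 0, 0],
    [1, 1, 0, 0],
    [1, 1, 1, 0],
    [0, 0, 0, 0],
], dtype=torch.float32)

overlap = mask.mean().item()  # 7 visible pixels / 16 total = 0.4375
```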

I believe I used the following code to calculate this, although it was written for a much older version of the codebase and might not work without modifications. It was also written to run across 4 GPUs, so you may need to remove the `data.sequences = data.sequences[int(gpu)::4]` line.

from data.scannetpp.scannetpp import ScanNetPPData
import utils.loss_mask as loss_mask
import torch
import json
import torchvision
import os
import sys

@torch.no_grad()
def calculate_loss_mask(targets, context):
    '''Calculate the loss mask for the target views in the batch'''

    target_depth = torch.stack([target_view['depthmap'] for target_view in targets], dim=1)
    target_intrinsics = torch.stack([target_view['camera_intrinsics'] for target_view in targets], dim=1)
    target_c2w = torch.stack([target_view['camera_pose'] for target_view in targets], dim=1)
    context_depth = torch.stack([context_view['depthmap'] for context_view in context], dim=1)
    context_intrinsics = torch.stack([context_view['camera_intrinsics'] for context_view in context], dim=1)
    context_c2w = torch.stack([context_view['camera_pose'] for context_view in context], dim=1)

    target_intrinsics = target_intrinsics[..., :3, :3]
    context_intrinsics = context_intrinsics[..., :3, :3]

    mask = loss_mask.calculate_in_frustum_mask(
        target_depth, target_intrinsics, target_c2w,
        context_depth, context_intrinsics, context_c2w
    )
    return mask

if __name__ == '__main__':

    root = DATA_ROOT_HERE  # set this to your ScanNet++ dataset root
    data = ScanNetPPData(root, 'val')
    resolution = (512, 512)

    gpu = sys.argv[1]
    device = torch.device(f'cuda:{gpu}')

    org_transform = torchvision.transforms.ToTensor()

    data.sequences = data.sequences[int(gpu)::4]

    for sequence_no, sequence in enumerate(data.sequences):

        output = {}
        print(f'Processing sequence {sequence_no + 1}/{len(data.sequences)}')
        output[sequence] = {}
        color_paths = data.color_paths[sequence]
        views = [data.get_view(sequence, i, resolution) for i in range(len(color_paths))]

        for view in views:
            view['original_img'] = org_transform(view['original_img']).to(device)
            view['depthmap'] = torch.tensor(view['depthmap']).unsqueeze(0).to(device)
            view['valid_mask'] = view['depthmap'] > 1e-6
            view['camera_intrinsics'] = torch.tensor(view['camera_intrinsics']).unsqueeze(0).to(device)
            view['camera_pose'] = torch.tensor(view['camera_pose']).unsqueeze(0).to(device)

        coverage_vals = []
        for i, color_path1 in enumerate(color_paths):

            print(f'Processing view {i + 1}/{len(color_paths)} of sequence {sequence_no + 1}/{len(data.sequences)}')

            coverage = []
            for batch in range(0, len(views), 100):
                masks = calculate_loss_mask(views[batch:min(batch+100,len(views))], [views[i]]).float()
                coverage.append(masks[0].mean(dim=-1).mean(dim=-1))
            coverage = torch.cat(coverage)
            assert coverage.shape[0] == len(views)
            coverage_vals.append(coverage)

        output[sequence] = torch.stack(coverage_vals).cpu().numpy().tolist()

        # Save the dictionary as a JSON file
        os.makedirs('coverage', exist_ok=True)
        output_path = os.path.join('coverage', f'{sequence}.json')
        with open(output_path, 'w') as f:
            json.dump(output, f)
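
If you do run it on multiple GPUs, each process takes its GPU index as the only command-line argument (i.e. launching the script four times with arguments 0 through 3), the data.sequences[int(gpu)::4] slice then splits the sequences between the processes, and the per-sequence JSON files can simply be merged afterwards.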

The test coverage sets were then produced using these calculated values:

import json
import random
import tqdm

# Add MAST3R and PixelSplat to the sys.path to prevent issues during importing
import sys
sys.path.append('src/pixelsplat_src')
sys.path.append('src/mast3r_src')
sys.path.append('src/mast3r_src/dust3r')

def generate_test_set(coverage, alpha, beta, save_file, max_samples_per_sequence=1000, max_samples_for_view=100):

    test_set = []
    for i, sequence in enumerate(coverage):

        samples_for_sequence = []

        number_of_views = len(coverage[sequence])

        print(f'Sequence {i+1}/{len(coverage)}: {sequence}')
        for c_view_1 in tqdm.tqdm(range(number_of_views)):

            samples_for_view = []

            for c_view_2 in range(number_of_views):

                overlap = coverage[sequence][c_view_1][c_view_2]
                if c_view_1 == c_view_2 or overlap < alpha:
                    continue

                for target_view in range(number_of_views):

                    target_overlap = max(coverage[sequence][c_view_1][target_view], coverage[sequence][c_view_2][target_view])
                    if target_view == c_view_1 or target_view == c_view_2 or target_overlap < beta:
                        continue

                    samples_for_view.append((sequence, c_view_1, c_view_2, target_view))

            samples_per_view = random.sample(samples_for_view, min(max_samples_for_view, len(samples_for_view)))
            samples_for_sequence = samples_for_sequence + samples_per_view

        samples_for_sequence = random.sample(samples_for_sequence, min(max_samples_per_sequence, len(samples_for_sequence)))
        test_set = test_set + samples_for_sequence

    with open(save_file, 'w') as f:
        json.dump(test_set, f)

    return test_set

if __name__ == '__main__':

    from data.scannetpp.scannetpp import ScanNetPPData

    root = '/media/brandon/anubis09/scannetpp'
    data = ScanNetPPData(root, 'test')
    coverage = {}
    for sequence in data.sequences:
        with open(f'./data/scannetpp/coverage/{sequence}.json', 'r') as f:
            sequence_coverage = json.load(f)
        coverage[sequence] = sequence_coverage[sequence]

    for alpha, beta in [(0.9, 0.9), (0.7, 0.7), (0.5, 0.5), (0.3, 0.3)]:
        generate_test_set(coverage, alpha, beta, f'test_set_{alpha}_{beta}.json')
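
Each entry in the resulting JSON files is a (sequence, context view 1, context view 2, target view) tuple, where alpha thresholds the overlap between the two context views and beta thresholds the overlap between the target view and its best context view.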

If you run into any issues please let me know!