nickgkan / 3d_diffuser_actor

Code for the paper "3D Diffuser Actor: Policy Diffusion with 3D Scene Representations"
https://3d-diffuser-actor.github.io/
MIT License

Negative Losses in distributed training #47

Closed: AnshShah3009 closed this issue 4 weeks ago

AnshShah3009 commented 1 month ago

I'm encountering an issue where the loss values become negative during distributed training. This behavior makes it difficult to determine which checkpoint to use for the final model. Additionally, it's challenging to assess whether the model has converged since all loss metrics exhibit this behavior.

I suspect this issue might be related to incorrect gradient accumulation across distributed processes. Has anyone experienced something similar, or can anyone offer guidance on how to address this problem?

[image: negative losses in distributed training]

nickgkan commented 1 month ago

Hi, we have never encountered this issue. In fact, all metrics are either absolute/squared errors or accuracies, which means they are all non-negative by definition. Are you following a setup that we use with the released code, i.e. RLBench or CALVIN, or is it another setup?
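Because every logged metric is non-negative by construction, a negative value can only come from the way values are collected or logged, not from training itself. A minimal sanity check along these lines could run on the merged metrics before logging. `find_negative_metrics` is a hypothetical helper, it works on plain lists (call `tensor.tolist()` on tensors first), and the example values, including the negative one, are made up for illustration.

```python
def find_negative_metrics(metrics):
    """Return {name: [negative values]} for every metric that contains one.

    `metrics` maps a metric name to a list of floats (e.g. tensor.tolist()).
    Squared/absolute errors and accuracies cannot be negative, so any hit
    points to a gathering or logging bug rather than a training problem.
    """
    bad = {}
    for name, values in metrics.items():
        negatives = [v for v in values if v < 0]
        if negatives:
            bad[name] = negatives
    return bad


# Made-up values: the second metric holds a value no L2 loss can produce.
logged = {
    "train-losses/mean/traj_action_mse": [0.7286, 0.6965],
    "train-losses/mean/traj_pos_l2": [0.4546, -0.0123],
}
print(find_negative_metrics(logged))  # {'train-losses/mean/traj_pos_l2': [-0.0123]}
```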

AnshShah3009 commented 1 month ago

I haven't made any changes to the code for this run, and I can't think of what the problem could be; the code also looks correct. I tried the same run on a different system with different GPUs and see a similar problem there too.

nickgkan commented 1 month ago

Hi,

Were you eventually able to run the code successfully?

If not, could you please provide some more details on the setup? Which dataset/script are you running?

Besides, I can see that you are using wandb, so this is newly introduced code. Could there be a problem in loss logging?

AnshShah3009 commented 1 month ago

Hi,

I checked the TensorBoard logs and the same issue shows up there. After debugging, I found that there are problems with the way information is gathered from the different ranks.

My conda environment: 3dda.txt

I will port the all_gather function in engine.py, which uses dist.all_gather(), to RPC (Remote Procedure Call)-based object gathering.
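The RPC port itself is not shown in the thread. A simpler alternative worth trying (a swapped-in technique, neither RPC nor the repository's code) is to copy every tensor to the CPU first and gather the resulting dicts with `torch.distributed.all_gather_object`, so the receiving rank never has to copy tensors between GPUs. A minimal sketch; `gather_dicts_on_cpu` is a hypothetical name:

```python
import torch
import torch.distributed as dist


def gather_dicts_on_cpu(a_dict):
    """Gather a dict of tensors from every rank, with all tensors on the CPU.

    Hypothetical helper, a sketch rather than the engine.py implementation.
    """
    # Copy every tensor to host memory before any communication, so the
    # gathered copies never require a direct GPU-to-GPU transfer.
    cpu_dict = {k: v.detach().cpu() for k, v in a_dict.items()}

    # Single-process fallback: no process group means there is nothing to gather.
    if not (dist.is_available() and dist.is_initialized()):
        return [cpu_dict]

    # all_gather_object pickles each rank's object and fills one slot per rank.
    gathered = [None] * dist.get_world_size()
    dist.all_gather_object(gathered, cpu_dict)
    return gathered
```

With the NCCL backend, `all_gather_object` stages the pickled data on `torch.cuda.current_device()`, so each rank should call `torch.cuda.set_device(...)` with its own GPU beforehand.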

Let me know if you have encountered something like this before.

nickgkan commented 4 weeks ago

Hi, we've never encountered such issues. You could try running on a single GPU.

I still don't know which dataset and script you're running.

AnshShah3009 commented 4 weeks ago

Hi,

I found exactly where the issue lies: it is in the sync function in engine.py. This is a thread that discusses the same problem.

Hardware details: 2 x RTX A6000

Here is an example that fails on my setup; I have not been able to find a way around it:

import torch
from collections import defaultdict

def merge_and_concat_dicts(dicts):
    merged = defaultdict(list)  # Create a defaultdict of lists

    # Determine the device of the first tensor for each key
    reference_device = {key: tensor.device for key, tensor in dicts[0].items()}

    # Iterate through the list of dictionaries
    for d in dicts:
        for key, tensor in d.items():
            # Move tensor to the reference device before concatenating
            merged[key].append(tensor.to(reference_device[key]))

    # Concatenate tensors in the list for each key
    for key in merged:
        merged[key] = torch.cat(merged[key])

    return dict(merged)  # Convert back to a regular dictionary

# Example dictionaries to merge
dicts = [
    {'train-losses/mean/traj_action_mse': torch.tensor([0.7286, 0.6965, 0.7724, 0.6727, 0.7222], device='cuda:0'),
     'train-losses/mean/traj_pos_l2': torch.tensor([0.4546, 0.4149, 0.4043, 0.4333, 0.4338], device='cuda:0'),
     # More keys and tensors...
    },
    {'train-losses/mean/traj_action_mse': torch.tensor([0.7014, 0.6866, 0.7255, 0.6587, 0.7330], device='cuda:1'),
     'train-losses/mean/traj_pos_l2': torch.tensor([0.4265, 0.4366, 0.3879, 0.4117, 0.4058], device='cuda:1'),
     # More keys and tensors...
    }
]

# Use the function to merge and concatenate
merged_dict = merge_and_concat_dicts(dicts)
print(merged_dict)

My setup's output:

{'train-losses/mean/traj_action_mse': tensor([0.7286, 0.6965, 0.7724, 0.6727, 0.7222, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], device='cuda:0'),
 'train-losses/mean/traj_pos_l2': tensor([0.4546, 0.4149, 0.4043, 0.4333, 0.4338, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], device='cuda:0')}
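The zeros appear exactly in the positions filled by the `cuda:1` tensors after they were copied to `cuda:0`. Assuming the direct device-to-device copy is what fails (the thread does not pin down the root cause), a quick check like the following compares a direct copy with one routed through the CPU. `check_peer_copy` is a hypothetical helper; `torch.cuda.can_device_access_peer` only reports whether peer access is supported, not whether copies actually come out correct.

```python
import torch


def check_peer_copy(src=0, dst=1, n=1_000_000):
    """Compare a direct GPU-to-GPU copy with a copy routed through the CPU.

    Returns None when fewer than two GPUs are visible.
    """
    if not torch.cuda.is_available() or torch.cuda.device_count() < 2:
        return None
    x = torch.rand(n, device=f"cuda:{src}")
    direct = x.to(f"cuda:{dst}")          # direct device-to-device copy
    via_cpu = x.cpu().to(f"cuda:{dst}")   # host round-trip, as in the fix below
    reference = x.cpu()
    return {
        "peer_access": torch.cuda.can_device_access_peer(src, dst),
        "direct_ok": torch.equal(direct.cpu(), reference),
        "via_cpu_ok": torch.equal(via_cpu.cpu(), reference),
    }


print(check_peer_copy())
```

If `direct_ok` is False while `via_cpu_ok` is True, the machine's peer-to-peer copy path is the likely culprit rather than the training code.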

FIX:

Instead of transferring tensors directly between GPUs, move the data to the CPU first:

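# Method inside engine.py, where torch, dist, all_gather and
# is_dist_avail_and_initialized are already in scope.
def synchronize_between_processes(self, a_dict):
    # all_gather returns one dict per rank
    all_dicts = all_gather(a_dict)

    if not is_dist_avail_and_initialized() or dist.get_rank() == 0:
        merged = {}
        for key in all_dicts[0].keys():
            device = all_dicts[0][key].device
            merged[key] = torch.cat([
                # Small change: copy through the CPU instead of
                # directly between GPUs
                p[key].to('cpu').to(device) for p in all_dicts
                if key in p
            ])
        a_dict = merged
    return a_dict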
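If the merged values are only used for logging (an assumption; the thread does not say whether later code in engine.py needs them on the GPU), the second copy back to the GPU can be skipped and everything concatenated on the CPU. A sketch with a hypothetical `merge_on_cpu` helper that keeps the same "skip ranks missing a key" behavior as the fix above:

```python
import torch


def merge_on_cpu(all_dicts):
    """Concatenate per-rank metric tensors on the CPU (hypothetical helper)."""
    merged = {}
    for key in all_dicts[0]:
        # Only ranks that reported this key contribute to it.
        merged[key] = torch.cat(
            [d[key].detach().cpu() for d in all_dicts if key in d]
        )
    return merged


# Usage with the values from the example earlier in this thread (on the CPU here):
rank0 = {"train-losses/mean/traj_pos_l2": torch.tensor([0.4546, 0.4149, 0.4043, 0.4333, 0.4338])}
rank1 = {"train-losses/mean/traj_pos_l2": torch.tensor([0.4265, 0.4366, 0.3879, 0.4117, 0.4058])}
merged = merge_on_cpu([rank0, rank1])
print({k: v.mean().item() for k, v in merged.items()})
```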