Luodian / Otter

🦦 Otter, a multi-modal model based on OpenFlamingo (open-sourced version of DeepMind's Flamingo), trained on MIMIC-IT and showcasing improved instruction-following and in-context learning ability.
https://otter-ntu.github.io/
MIT License

Question about missing values in `batch_mimicit`, possibly caused by `delete_tensors_from_dict` and `pop` #327

Li-Qingyun closed this issue 6 months ago

Li-Qingyun commented 6 months ago

Hi, it seems that the `pop` and `delete_tensors_from_dict` operations result in some samples going missing.

What are these operations for?

The operations include:

```python
delete_tensors_from_dict(batch_mimicit)
delete_tensors_from_dict(
    {
        "other": [
            images,
            input_ids,
            attention_mask,
            labels,
        ]
    }
)
```
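
For reference, here is a rough sketch of what I assume `delete_tensors_from_dict` does, based on its name and call sites (not the actual implementation): it walks a nested dict/list and deletes tensor values in place, presumably to release memory as early as possible.

```python
import torch

# Rough sketch of what I assume delete_tensors_from_dict does (not the
# actual Otter implementation): recursively delete tensor values from a
# nested dict/list, in place.
def delete_tensors_from_dict_sketch(obj):
    if isinstance(obj, dict):
        for key in list(obj.keys()):
            if isinstance(obj[key], torch.Tensor):
                del obj[key]  # in-place deletion mutates the original batch dict
            else:
                delete_tensors_from_dict_sketch(obj[key])
    elif isinstance(obj, list):
        for item in obj:
            delete_tensors_from_dict_sketch(item)
```

If it really deletes keys in place like this, every other reference to the same batch dict sees the keys disappear.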

I encountered an empty `net_input` at the tail of an epoch during training.


But the logic in `collate_fn` suggests that a freshly collated batch should at least contain `input_ids` and `attention_masks`.
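
For context, I would expect `collate_fn` to do something like this (a simplified sketch with made-up padding logic, not the actual MIMIC-IT collate code), which always populates `net_input` with both keys:

```python
import torch

# Simplified, hypothetical sketch of a collate_fn: a freshly collated
# batch always carries net_input with input_ids and attention_masks.
def collate_fn_sketch(samples):
    input_ids = torch.nn.utils.rnn.pad_sequence(
        [s["input_ids"] for s in samples], batch_first=True, padding_value=0
    )
    attention_masks = (input_ids != 0).long()
    return {
        "net_input": {
            "input_ids": input_ids,
            "attention_masks": attention_masks,
        }
    }
```

So the emptiness has to come from something that mutates the batch after collation.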

I have not found why `net_input` is empty.

However, when I remove `delete_tensors_from_dict` in my debugging script (which skips the forward/backward pass and only keeps data loading, to make the error surface faster), the bug disappears.

This is strange.

Additionally, the error that actually surfaces is an NCCL timeout: `...[Rank x] Watchdog caught collective operation timeout: WorkNCCL...` (presumably one rank hits the bad batch and stalls or exits while the other ranks keep waiting in a collective until the watchdog fires). I found the true error by adding this:

```python
def train_one_epoch(args, model, epoch, mimicit_loaders, tokenizer, optimizer, lr_scheduler, device_id, accelerator, wandb):
    # cycle() re-yields the same batch objects once a dataloader is exhausted
    dataloader_iterators = [cycle(dataloader) for dataloader in mimicit_loaders]
    ......

    for num_steps in tqdm(range(args.total_training_steps), disable=args.rank != 0, initial=(epoch * num_batches_per_epoch)):
        ......

        #### MIMIC-IT FORWARD PASS ####
        try:
            # read the keys without pop, so the batch dict is not mutated
            net_input = batch_mimicit["net_input"]
            images = net_input["patch_images"].to(device_id, non_blocking=True)
            input_ids = net_input["input_ids"].to(device_id, non_blocking=True)
            attention_mask = net_input["attention_masks"].to(device_id, non_blocking=True)
            labels = None  # placeholder to avoid error
            # original version, which pops (and therefore mutates) the batch dict:
            # net_input = batch_mimicit.pop("net_input")
            # images = net_input.pop("patch_images").to(device_id, non_blocking=True)
            # input_ids = net_input.pop("input_ids").to(device_id, non_blocking=True)
            # attention_mask = net_input.pop("attention_masks").to(device_id, non_blocking=True)
            # labels = None  # placeholder to avoid error
        except Exception as err:
            # dump the offending batch before the NCCL watchdog hides the real error
            print(f"\n\n\n{err}\n")
            print(batch_mimicit)
            exit()

        ......

            try:
                loss_mimicit = forward_pass(
                    args,
                    model,
                    tokenizer,
                    images,
                    input_ids,
                    attention_mask,
                    labels,
                    device_id,
                    autocast_type,
                    batch_mimicit,
                )
            except Exception as err:
                print("\nforward_pass\n")
                print(f"\n\n\n{err}\n")
                print(batch_mimicit)
                exit()
            ......

            # delete_tensors_from_dict(batch_mimicit)
            # delete_tensors_from_dict(
            #     {
            #         "other": [
            #             images,
            #             input_ids,
            #             attention_mask,
            #             labels,
            #         ]
            #     }
            # )

        ......
```

Maybe `cycle(dataloader)` makes some samples appear again, but their values were already deleted by the function?
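
To illustrate the suspected mechanism, here is a minimal standalone reproducer (hypothetical toy data, not the Otter code): `itertools.cycle` stores references to the items it has already yielded and re-yields the same objects on the next pass, so popping keys in place empties the batches that the second pass will see.

```python
from itertools import cycle

# Hypothetical minimal reproducer: two toy "batches" stand in for the
# dataloader output.
batches = [
    {"net_input": {"input_ids": [1, 2]}},
    {"net_input": {"input_ids": [3, 4]}},
]
it = cycle(batches)

for step in range(4):
    batch = next(it)
    net_input = batch.pop("net_input", None)  # mutates the cached dict in place
    print(step, net_input)

# Steps 0 and 1 print the net_input dicts; steps 2 and 3 print None,
# because cycle() re-yields the already-emptied dict objects.
```

If this is indeed the cause, re-creating the iterators each epoch, or working on a shallow copy (e.g. `batch = dict(batch_mimicit)`) before popping, should avoid it.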