microsoft / unilm

Large-scale Self-supervised Pre-training Across Tasks, Languages, and Modalities
https://aka.ms/GeneralAI
MIT License

BEiT training loss is not reducing after few epochs #1268

Open senthil-r-10 opened 1 year ago

senthil-r-10 commented 1 year ago

Hi, I followed the steps in the notebook to train the BEiT model, but the loss stops decreasing after a few epochs: it starts at 7.2 and stagnates at 4.3.

I used 1M documents from the IIT-CDIP dataset, resized to 224x224 with BeitImageProcessor.

Model: `config = BeitConfig(use_relative_position_bias=True, use_mask_token=True)`, `model = BeitForMaskedImageModeling(config)`

Loss: `CrossEntropyLoss`

Optimizer: `AdamW(model.parameters(), lr=1e-5, weight_decay=0.05)`

LR scheduler: `ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=0, threshold=0.0001, threshold_mode='rel', cooldown=0, min_lr=1e-6, eps=1e-08, verbose=True)`

Training code:
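With `patience=0` and `factor=0.5`, the scheduler above halves the learning rate after every epoch whose mean loss does not improve by the relative threshold. The sketch below is a simplified pure-Python re-implementation of that plateau rule (not PyTorch's own code), run on a made-up loss curve to show how quickly `lr=1e-5` can fall to `min_lr=1e-6`. The loss values are hypothetical, and this is only one possible contributor to the stagnation, not a diagnosis.

```python
def simulate_plateau_lr(losses, lr=1e-5, factor=0.5, patience=0,
                        threshold=1e-4, min_lr=1e-6, eps=1e-8):
    """Simplified model of ReduceLROnPlateau (mode='min', threshold_mode='rel',
    cooldown=0). Returns the learning rate in effect after each epoch."""
    best = float("inf")
    num_bad_epochs = 0
    history = []
    for loss in losses:
        # An epoch counts as an improvement only if it beats the best loss
        # by more than the relative threshold.
        if loss < best * (1.0 - threshold):
            best = loss
            num_bad_epochs = 0
        else:
            num_bad_epochs += 1
        if num_bad_epochs > patience:
            new_lr = max(lr * factor, min_lr)
            if lr - new_lr > eps:
                lr = new_lr
            num_bad_epochs = 0
        history.append(lr)
    return history


# Hypothetical per-epoch mean losses: fast initial drop, then a plateau.
example_losses = [7.2, 5.0, 4.4, 4.4, 4.3999, 4.4, 4.4, 4.4]
lr_history = simulate_plateau_lr(example_losses)
for epoch, (loss, lr) in enumerate(zip(example_losses, lr_history), 1):
    print(f"epoch {epoch}: loss={loss:.4f} lr={lr:.3e}")
```

In this toy run, four non-improving epochs are enough to reach the floor of `1e-6`. A larger `patience` or a fixed warmup-plus-decay schedule would be worth comparing against.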

```python
import time

import numpy as np
import torch

# MaskingGenerator is assumed to come from the BEiT repo's masking_generator.py.


def train(model, encoder, train_dataloader, device, rank, optimizer, loss_fn,
          scheduler, epochs, save_model_dir):
    # `.module` assumes the model is wrapped, e.g. in DistributedDataParallel
    window_size = model.module.beit.embeddings.patch_embeddings.patch_shape
    torch.set_grad_enabled(True)
    num_masking_patches = 310  # 75
    max_mask_patches_per_block = None
    min_mask_patches_per_block = 32  # 16
    # generates a block-wise mask for each image
    mask_generator = MaskingGenerator(
        window_size, num_masking_patches=num_masking_patches,
        max_num_patches=max_mask_patches_per_block,
        min_num_patches=min_mask_patches_per_block,
    )
    max_grad_norm = 3.0  # for base BEiT model

    for epoch in range(1, epochs + 1):  # run exactly `epochs` epochs
        total_loss = 0
        st_time = time.time()

        for step_count, batch in enumerate(train_dataloader, 1):
            print(step_count, end="\r")
            optimizer.zero_grad()
            pixel_values, pixel_values_dall_e = batch[0], batch[1]

            # one independent mask per image, flattened to (batch, num_patches)
            batch_bool_masked_pos = np.stack(
                [mask_generator() for _ in range(pixel_values_dall_e.shape[0])]
            )
            batch_bool_masked_pos = (
                torch.from_numpy(batch_bool_masked_pos).flatten(1).to(torch.bool).to(device)
            )

            with torch.no_grad():
                # visual tokens from the DALL-E encoder serve as targets
                z_logits = encoder(pixel_values_dall_e.to(device))
                input_ids = torch.argmax(z_logits, dim=1).flatten(1)
                labels = input_ids[batch_bool_masked_pos]

            pixel_values = pixel_values.to(device)
            outputs = model(pixel_values, batch_bool_masked_pos)

            loss = loss_fn(outputs.logits[batch_bool_masked_pos], labels)
            total_loss += loss.item()
            loss.backward()

            # Perform gradient clipping
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
            optimizer.step()

        # ReduceLROnPlateau is stepped once per epoch on the mean training loss
        scheduler.step(total_loss / step_count)
```
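One thing worth sanity-checking in the code above is the masking target relative to the patch grid. The sketch below assumes the default 16x16 patch size, so a 224x224 input gives a 14x14 `window_size`. It compares the two values that appear in the post: 310, and the 75 left in the trailing comment. `mask_ratio` is a hypothetical helper, not part of BEiT or `transformers`.

```python
def mask_ratio(window_size, num_masking_patches):
    """Fraction of the patch grid that the masking target asks for."""
    height, width = window_size
    return num_masking_patches / (height * width)


# Assumption: 224x224 input with 16x16 patches, i.e. a 14x14 grid.
grid = (224 // 16, 224 // 16)
total_patches = grid[0] * grid[1]

for target in (75, 310):
    ratio = mask_ratio(grid, target)
    note = "exceeds the number of patches" if ratio > 1 else "within the grid"
    print(f"target={target}, patches={total_patches}, ratio={ratio:.2f} ({note})")
```

If the ratio is above 1, the generator cannot leave any meaningful number of patches visible, so it may be worth confirming what fraction of each image is actually masked (for example, by printing `batch_bool_masked_pos.float().mean()` during training).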

donglixp commented 1 year ago

The model is still learning. You might use the fine-tuning performance as a more robust indicator.

senthil-r-10 commented 1 year ago

Will increasing the input size from 224x224 to 384x384 or 672x672, with a patch size of 16x16, help the model converge? The paper reports that the BEiT-L 384x384 model performs better than BEiT-L 224x224.
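For a sense of what the larger inputs would cost, here is a rough sketch of how the patch count grows with resolution at a fixed 16x16 patch size. `patch_stats` is a hypothetical helper. The "relative attention cost" is only the ratio of squared sequence lengths (ignoring the [CLS] token and everything outside self-attention), so treat it as an order-of-magnitude guide. It says nothing about whether convergence improves.

```python
def patch_stats(image_size, patch_size=16):
    """Return (patches per side, total patches) for a square input."""
    if image_size % patch_size != 0:
        raise ValueError("image_size must be a multiple of patch_size")
    side = image_size // patch_size
    return side, side * side


baseline_patches = patch_stats(224)[1]

for size in (224, 384, 672):
    side, num_patches = patch_stats(size)
    # Self-attention compares every patch with every other patch,
    # so its cost grows with the square of the sequence length.
    relative_attention_cost = (num_patches / baseline_patches) ** 2
    print(f"{size}x{size}: {side}x{side} grid, {num_patches} patches, "
          f"~{relative_attention_cost:.1f}x attention cost vs 224x224")
```

Any masking settings such as `num_masking_patches` would also need to be rescaled to the new grid size.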