huggingface / pytorch-image-models

The largest collection of PyTorch image encoders / backbones. Including train, eval, inference, export scripts, and pretrained weights -- ResNet, ResNeXT, EfficientNet, NFNet, Vision Transformer (ViT), MobileNetV4, MobileNet-V3 & V2, RegNet, DPN, CSPNet, Swin Transformer, MaxViT, CoAtNet, ConvNeXt, and more
https://huggingface.co/docs/timm
Apache License 2.0

eva2_mim model, training not converging #1759

Closed keertika-11 closed 1 year ago

keertika-11 commented 1 year ago

Hi, I am retraining the eva02 MIM model, but the model is not converging; I get 0% test accuracy.

To Reproduce:

from timm.data.transforms_factory import create_transform
from timm.data import create_dataset, create_loader
import torchvision
from matplotlib import pyplot as plt
import numpy as np
import torch
import timm
import timm.optim
import timm.scheduler  # needed for timm.scheduler.create_scheduler_v2 below
from torchvision import transforms
from tqdm import tqdm

train_ds = create_dataset(name='v1', root='1946_train_val/train', is_training=True, transform=create_transform(448))
train_loader = create_loader(train_ds,
                             input_size=(3,448,448),
                             batch_size=32,
                             use_prefetcher=False,
                             is_training=True,
                             no_aug=True,
                             num_workers=1,
                             interpolation='bicubic',
                             device=torch.device('cuda'))

val_ds = create_dataset(name='v1', root='1946_train_val/validation/', is_training=False, transform=create_transform(448))
val_loader = create_loader(val_ds,
                             input_size=(3,448,448),
                             batch_size=1,
                             use_prefetcher=False,
                             is_training=False,
                             no_aug=True,
                            device=torch.device('cuda'), 
                            interpolation='bicubic')

def test_loop(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    test_loss, correct = 0, 0
    with torch.no_grad():
        for x, y in tqdm(dataloader):

            x, y = x.to('cuda'), y.to('cuda')
            pred = model(x)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()

    test_loss /= num_batches
    correct /= size
    print((100*correct))
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

if __name__ == '__main__':

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    "Creating CoAtNet Model"
    model = timm.create_model(""eva02_large_patch14_448.mim_m38m_ft_in22k_in1k", pretrained=True, num_classes=1926, scriptable=True,
                              exportable=True)
    model = torch.nn.DataParallel(model)
    model.to(device)
    data_config = timm.data.resolve_model_data_config(model)
    print(data_config)
    "Creating optimizer"
    updates_per_epoch = len(train_loader)
    optimizer = timm.optim.create_optimizer_v2(model, opt='adam', lr=0.0001)
    lr_scheduler, num_epochs = timm.scheduler.create_scheduler_v2(
        optimizer,
        sched='cosine',
        num_epochs=10,
        decay_epochs=7,
        decay_milestones=(3, 5, 7),
        decay_rate=0.1,
        warmup_lr=1e-05,
        warmup_epochs=0,
        step_on_epochs=True,
        updates_per_epoch=updates_per_epoch,
    )
    "creating Loss function"
    loss_fn = torch.nn.CrossEntropyLoss()
    for epoch in range(10):
        size = len(train_loader.dataset)
        for batch,(inputs, targets) in enumerate(tqdm(train_loader)):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = loss_fn(outputs, targets)
            # print(loss)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            if batch % 100 == 0: 
                loss, current = loss.item(), batch * len(inputs)
                print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")
                # print(end-start ," time taken for 1 step")

        test_loop(val_loader, model,loss_fn)
        torch.save(model, 'eva_vit_l14_448_1946'+'epoch-{}.pt'.format(epoch))

    model = model.eval()
    # saving the entire model
    torch.save(model,"vit_l_p14_rcp_1925.pt")

What am I doing wrong?

[Screenshot attached: training output, 2023-04-07 15:40]


rwightman commented 1 year ago

@keertika-11 Nothing is wrong with the model; it has been tested. I can't get into detailed debugging of others' train scripts or hparams, but your LR is 1-2 orders of magnitude too large for that batch size, and I'd never train these models without gradient clipping. Moving to discussions in case anyone else has suggestions, as this isn't a bug...
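
For illustration, a minimal sketch of how the advice above could be applied to the training loop from the original post, assuming an example learning rate of 1e-5 and a max gradient norm of 1.0 (illustrative values only, not settings confirmed in this thread):

# Hedged sketch: lower learning rate plus gradient norm clipping.
# lr=1e-5 and max_norm=1.0 are assumed example values, not maintainer recommendations.
optimizer = timm.optim.create_optimizer_v2(model, opt='adam', lr=1e-5)

for batch, (inputs, targets) in enumerate(tqdm(train_loader)):
    inputs, targets = inputs.to(device), targets.to(device)
    loss = loss_fn(model(inputs), targets)
    loss.backward()
    # Clip the gradient norm before the optimizer step to keep updates bounded
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()
    optimizer.zero_grad()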