lukemelas / EfficientNet-PyTorch

A PyTorch implementation of EfficientNet
Apache License 2.0

Getting High Loss Values #219

Open IamSparky opened 4 years ago

IamSparky commented 4 years ago

I am using the EfficientNet-B7 model, but I am getting huge loss values and can't figure out where I am going wrong.

downloading the pretrained model - efficientnet b7

!pip install efficientnet_pytorch

import efficientnet_pytorch

model = efficientnet_pytorch.EfficientNet.from_pretrained('efficientnet-b7')

changing the last layer from a 1000-class classifier to a 104-class flower classifier

in_features = model._fc.in_features
model._fc = torch.nn.Linear(in_features, 104)

freezing all the layers so the pretrained weights stay fixed

for param in model.parameters():
    param.requires_grad = False

leaving only the last layer trainable

for param in model._fc.parameters():
    param.requires_grad = True
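
As a quick sanity check that the freezing worked as intended, you can list which parameters still require gradients; a minimal sketch based on the code above:

# sanity check: only the new _fc layer should still require gradients
trainable = [name for name, p in model.named_parameters() if p.requires_grad]
print(trainable)  # expected: ['_fc.weight', '_fc.bias']
print(sum(p.numel() for p in model.parameters() if p.requires_grad))  # parameter count of the new head only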

setting up the optimizer, loss function, and scheduler for training

!pip install torchtoolbox
!pip install torchcontrib

from torchtoolbox.tools import mixup_data, mixup_criterion

# for Stochastic Weight Averaging (SWA) in PyTorch
from torchcontrib.optim import SWA

base_optimizer = torch.optim.Adam(model._fc.parameters(), lr=1e-4)

optimizer = SWA(base_optimizer, swa_start=5, swa_freq=5, swa_lr=0.05)

loss_fn = torch.nn.CrossEntropyLoss()

scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)
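
For context on the magnitude of the losses: mixup_criterion evaluates the loss against both sets of labels returned by mixup_data and blends them, so the printed values are not directly comparable to plain cross-entropy on unmixed labels. A minimal sketch of the standard mixup loss (assuming torchtoolbox follows the usual formulation):

# standard mixup loss: convex combination of the losses against the two label sets,
# weighted by the same lambda used to mix the inputs
def mixup_loss_sketch(criterion, pred, y_a, y_b, lam):
    return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)

Because the targets are mixed, this loss typically sits somewhat above what the unmixed criterion would report, which is worth keeping in mind when judging whether the values are really too high.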

finally, training

for epoch in range(epochs):
    print('Epoch ', epoch, '/', epochs - 1)
    print('-' * 15)

    for phase in ['train', 'val']:
        if phase == 'train':
            model.train()  # set model to training mode
        else:
            model.eval()   # set model to evaluation mode

        running_loss = 0.0
        running_corrects = 0.0

        alpha = 0.2

        # iterate over data
        for i, (inputs, labels) in enumerate(dataloaders[phase]):
            if torch.cuda.is_available():
                inputs = inputs.cuda()
                labels = labels.cuda()

            inputs, labels_a, labels_b, lam = mixup_data(inputs, labels, alpha)

            # zero the parameter gradients
            optimizer.zero_grad()

            with torch.set_grad_enabled(phase == 'train'):

                outputs = model(inputs)
                _, preds = torch.max(outputs, 1)
                loss = mixup_criterion(loss_fn, outputs, labels_a, labels_b, lam)

                # loss = loss_fn(outputs, labels)

                # backpropagate and update parameters only in training mode
                if phase == 'train':
                    loss.backward()
                    optimizer.step()

            running_loss += loss.item() * inputs.size(0)
            # note: using .data here is not recommended, as it can have unwanted side effects
            running_corrects += torch.sum(preds == labels.data)

        # step the learning-rate scheduler once per training epoch
        if phase == 'train':
            scheduler.step()

        epoch_loss = running_loss / float(dataset_sizes[phase])
        epoch_acc = running_corrects / float(dataset_sizes[phase])

        print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))

# swap in the SWA running averages once training is finished
optimizer.swap_swa_sgd()

[attached screenshot: training output showing the high loss values]

Lg955 commented 3 years ago

Hi, I have two questions: 1. Have you solved the problem? 2. If I want to save a checkpoint every epoch, where should I put the call to optimizer.swap_swa_sgd()?
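
For reference, torchcontrib's swap_swa_sgd() swaps the SWA running averages into the model parameters, and calling it a second time swaps the current weights back in. One possible pattern for saving an SWA checkpoint every epoch is sketched below (train_one_epoch is a hypothetical helper standing in for the training loop above):

# sketch: save an SWA checkpoint at the end of each epoch, then keep training
for epoch in range(epochs):
    train_one_epoch(model, optimizer)   # hypothetical per-epoch training step

    optimizer.swap_swa_sgd()            # put the SWA-averaged weights into the model
    torch.save(model.state_dict(), 'swa_epoch_{}.pth'.format(epoch))
    optimizer.swap_swa_sgd()            # swap the current weights back and continue training

Before swa_start is reached there should be nothing to swap yet, so those early checkpoints would simply contain the current weights.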

IamSparky commented 3 years ago

Yes, I have solved the problem, thanks.

Lg955 commented 3 years ago

This has been troubling me for days. Could I have a look at your new training code? If so, I will leave my email. Thank you!

IamSparky commented 3 years ago

Cool..here's my notebook https://www.kaggle.com/soumochatterjee/cutmix-flower-classification

Lg955 commented 3 years ago

OK, Thank you!

GabPrato commented 3 years ago

@soumochatterjee would you mind sharing what your problem was and how you fixed it? Thanks.