fadel / pytorch_ema

Tiny PyTorch library for maintaining a moving average of a collection of parameters.
MIT License

Question about the use of ExponentialMovingAverage #15

Open: sevennotmouse opened this issue 3 weeks ago

sevennotmouse commented 3 weeks ago

Here is my code:

from torch_ema import ExponentialMovingAverage
model = ...
optimizer = ...
scheduler = ...
ema_model = ExponentialMovingAverage(parameters=pg, decay=0.9999)

for epoch in range(args.epochs):  # train
    print('epoch:',epoch,'Current learning rate:', optimizer.param_groups[0]['lr'])
    train_loss, train_MAE, tb_writer = train_one_epoch(model=model,
                                            optimizer=optimizer,
                                            data_loader=train_loader,
                                            device=device,
                                            epoch=epoch,
                                            scheduler=scheduler,
                                            csv_filename=args.csv_filename,
                                            tb_writer=tb_writer)

    scheduler.step()
    ema_model.update()
    # validate
    with ema_model.average_parameters():
        val_loss, val_MAE = evaluate(model=model, data_loader=val_loader, device=device, epoch=epoch)

As shown in the code, I run the evaluate function on the validation set after each training epoch. I found that the validation results are exactly the same when I set different decay values. Why is that? The evaluate function is as follows:

@torch.no_grad()
def evaluate(model, data_loader, device, epoch):
    softceloss_function = SoftCrossEntropy()
    model.eval()
    data_loader = tqdm(data_loader)
    for step, data in enumerate(data_loader):
        images, names, labels = data
        pred = model(images.to(device))
        softlabel = softlabel_function(labels)  # a function to convert labels to soft labels
        loss = softceloss_function(pred, softlabel.to(device))

    val_loss, val_MAE = ...  # calculate loss and MAE
    return val_loss, val_MAE

fadel commented 3 weeks ago

Hi sevennotmouse,

ema_model = ExponentialMovingAverage(parameters=pg, decay=0.9999)

Is pg changing in your example? Where does it come from? If there is a chance that pg does not change, maybe that would explain the behaviour you are seeing.
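A quick way to test this hypothesis is to check that pg holds the same tensor objects the model uses, snapshot them, take one optimizer step, and compare. The sketch below assumes plain PyTorch only; the tiny nn.Sequential (with only its last layer trainable), the random inputs and the params_changed helper are hypothetical stand-ins for illustration, not code from the issue.

```python
import torch
import torch.nn as nn


def params_changed(before, after):
    """Return True if any tensor in `after` differs from its snapshot in `before`."""
    return any(not torch.equal(b, a) for b, a in zip(before, after))


torch.manual_seed(0)
# Hypothetical stand-in model; only the last Linear layer ("2.*") stays trainable,
# mimicking the "freeze some parameters" setup.
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
for name, p in model.named_parameters():
    p.requires_grad_(name.startswith("2."))
pg = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(pg, lr=0.1)

# pg must hold the *same* tensor objects the model uses, not copies of them.
model_param_ids = {id(p) for p in model.parameters()}
assert all(id(p) in model_param_ids for p in pg)

snapshot = [p.detach().clone() for p in pg]
loss = model(torch.randn(16, 4)).pow(2).mean()
optimizer.zero_grad()
loss.backward()
optimizer.step()

print("pg changed after one step:", params_changed(snapshot, pg))
```

If this prints False for the real model, the EMA would only ever average the initial weights, which would give the same result for every decay.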

sevennotmouse commented 2 weeks ago

Thanks for your reply. Let me clarify the code: pg is the list of trainable parameters in the model (I freeze some of the model's parameters during training). Here is the detailed code:

from torch_ema import ExponentialMovingAverage
model = ...
for name, para in model.named_parameters():
    if "blocks" in name or "head" in name:
        para.requires_grad_(True)
    else:
        para.requires_grad_(False)
pg = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.SGD(pg, lr=0.01, momentum=0.9, weight_decay=5e-5)
scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=20, eta_min=0)
ema_model = ExponentialMovingAverage(parameters=pg, decay=0.9999)

best_MAE=10
save_path=...
for epoch in range(args.epochs):  # train
    print('epoch:',epoch,'Current learning rate:', optimizer.param_groups[0]['lr'])
    train_loss, train_MAE, tb_writer = train_one_epoch(model=model,
                                            optimizer=optimizer,
                                            data_loader=train_loader,
                                            device=device,
                                            epoch=epoch,
                                            scheduler=scheduler,
                                            csv_filename=args.csv_filename,
                                            tb_writer=tb_writer)

    scheduler.step()
    ema_model.update()
    # validate
    with ema_model.average_parameters():
        val_loss, val_MAE = evaluate(model=model, data_loader=val_loader, device=device, epoch=epoch)
    if val_MAE < best_MAE:
        best_MAE=val_MAE
        torch.save(ema_model.state_dict(), save_path) 

The evaluate function is as follows:

@torch.no_grad()
def evaluate(model, data_loader, device, epoch):
    softceloss_function = SoftCrossEntropy()
    model.eval()
    data_loader = tqdm(data_loader)
    for step, data in enumerate(data_loader):
        images, names, labels = data
        pred = model(images.to(device))
        softlabel = softlabel_function(labels) # a function to convert labels to softlabel
        loss = softceloss_function(pred, softlabel.to(device))

    val_loss,val_MAE = ... # calculate loss and MAE
    return val_loss, val_MAE
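In evaluate as written, loss is overwritten on every batch, so anything derived from it after the loop only reflects the last batch. How val_MAE is computed is elided in the issue, so the sketch below only shows the accumulation pattern: a small, hypothetical RunningMean helper (plain Python, not part of PyTorch or torch_ema) that keeps a batch-size-weighted average. Inside evaluate it could be fed something like loss.item() and images.size(0) on each step, assuming the loss is a per-sample mean.

```python
class RunningMean:
    """Batch-size-weighted running mean (hypothetical helper for illustration)."""

    def __init__(self):
        self.total = 0.0
        self.count = 0

    def update(self, batch_mean, batch_size):
        # batch_mean is the per-sample average for one batch, e.g. loss.item()
        self.total += batch_mean * batch_size
        self.count += batch_size

    @property
    def mean(self):
        return self.total / self.count if self.count else float("nan")


# Usage: two batches, the last one smaller (as often happens at the end of a DataLoader).
val_loss = RunningMean()
val_loss.update(1.0, 3)  # batch of 3 samples with mean loss 1.0
val_loss.update(2.0, 1)  # batch of 1 sample with mean loss 2.0
print(val_loss.mean)     # (1.0 * 3 + 2.0 * 1) / 4
```

Weighting by batch size keeps a smaller final batch from counting as much as a full one.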

At the end of each training epoch, I run evaluate on the validation set, and if val_MAE < best_MAE, I save a checkpoint. After 20 epochs, I take the model that performed best on the validation set and evaluate it on the test set.

The results are as follows:

  1. If I don't use pytorch_ema, or if I evaluate without the with ema_model.average_parameters(): block, the val_MAE for epochs 1, 2 and 3 is 6.185, 5.779 and 5.529, respectively.
  2. If I use the code above with decay set to 0.9999 (ema_model = ExponentialMovingAverage(parameters=pg, decay=0.9999)), the val_MAE for epochs 1, 2 and 3 is 6.269, 5.878 and 5.548, respectively. This shows that the EMA is taking effect.
  3. When I set decay to another value, such as 0.999 (ema_model = ExponentialMovingAverage(parameters=pg, decay=0.999)), and restart training, the validation results are exactly the same: the val_MAE for epochs 1, 2 and 3 is again 6.269, 5.878 and 5.548, respectively. The same happens with any other decay value I try.
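A likely explanation for case 3, as far as we can tell from torch_ema's source: ExponentialMovingAverage uses use_num_updates=True by default, and in that mode the n-th call to update() uses min(decay, (1 + n) / (10 + n)) instead of decay. With ema_model.update() called once per epoch, n never exceeds 20 here, so the warm-up term (about 0.18, 0.25 and 0.31 for n = 1, 2, 3) is always smaller than both 0.999 and 0.9999, and the decay argument never takes effect. The plain-Python sketch below mirrors that update rule on a single scalar; the helper names and the parameter trajectory are hypothetical, chosen only to compare the two settings.

```python
def effective_decay(decay, n):
    """Decay used on the n-th update() call, assuming torch_ema's warm-up rule
    min(decay, (1 + n) / (10 + n)) with use_num_updates=True."""
    return min(decay, (1 + n) / (10 + n))


def ema_history(values, decay):
    """Scalar EMA: values[0] is the parameter when the EMA is created,
    values[1:] are the parameter values seen by successive update() calls."""
    shadow = values[0]
    history = []
    for n, value in enumerate(values[1:], start=1):
        d = effective_decay(decay, n)
        shadow -= (1.0 - d) * (shadow - value)  # same form as torch_ema's in-place update
        history.append(shadow)
    return history


def first_update_where_decays_differ(d1, d2):
    """Smallest update count n at which the two decay settings behave differently."""
    n = 1
    while effective_decay(d1, n) == effective_decay(d2, n):
        n += 1
    return n


# Hypothetical parameter value after epochs 0..3 (one update per epoch, as in the issue).
per_epoch = [0.0, 1.0, 2.0, 3.0]
print(ema_history(per_epoch, 0.999))
print(ema_history(per_epoch, 0.9999))  # identical to the line above
print(first_update_where_decays_differ(0.999, 0.9999))  # roughly 9,000 updates
```

Under this rule, 0.999 and 0.9999 only start to differ after roughly 9,000 updates, and the nominal 0.9999 is only reached after about 90,000, which is why calling update() after every optimizer step matters for the decay setting to have any effect.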

To summarize, I have two questions:

  1. Why are the evaluation results exactly the same when I train with different decay values?
  2. Since ema_model.state_dict() is different from model.state_dict(), how should I save a checkpoint of the EMA weights and apply it to the test set?

Looking forward to your reply!