Open linzengmin opened 6 months ago
I'm confused about why this model is saved here. From the code, the model itself does not seem to be trained; it looks like it is still just the pre-trained model.
```python
torch.save(
    {
        "epoch": epoch,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
    },
    os.path.join(checkpoint_dir, "best.pt"),
)
```
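One way to check whether `best.pt` holds anything other than the pre-trained weights is to compare its `model_state_dict` against a copy of the original weights. The sketch below uses a toy `nn.Linear` as a hypothetical stand-in for the backbone and a hypothetical helper `changed_params`. It is not code from the repo.

```python
import torch
import torch.nn as nn


def changed_params(before: dict, after: dict, atol: float = 0.0) -> list:
    """Return the names of tensors that differ between two state_dicts."""
    changed = []
    for name, tensor in before.items():
        # .float() so integer buffers (e.g. num_batches_tracked) compare cleanly
        if not torch.allclose(tensor.float(), after[name].float(), atol=atol, rtol=0.0):
            changed.append(name)
    return changed


# Hypothetical stand-in for the pre-trained backbone.
torch.manual_seed(0)
model = nn.Linear(4, 2)
pretrained = {k: v.clone() for k, v in model.state_dict().items()}

# Simulate one optimizer step that does update the model's own parameters.
opt = torch.optim.SGD(model.parameters(), lr=0.1)
loss = model(torch.randn(8, 4)).pow(2).sum()
loss.backward()
opt.step()

print(changed_params(pretrained, model.state_dict()))
```

With the real checkpoint, the same helper could be applied to `torch.load(path, map_location="cpu")["model_state_dict"]` and the backbone's freshly loaded pre-trained weights. An empty list would mean the saved model is identical to the pre-trained one.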
Training code:
```python
model = model.to(device)
model.eval()

loss_fn = DSVDD(model, train_loader, args.cnn, args.gamma_c, args.gamma_d, device)
loss_fn = loss_fn.to(device)

epochs = 30
params = [{'params': loss_fn.parameters()},]
optimizer = optim.AdamW(params=params, lr=1e-3, weight_decay=5e-4, amsgrad=True)

best_pxl_pro = -1
for epoch in tqdm(range(epochs), '%s -->' % (class_name)):
    r'TEST PHASE'

    test_imgs = list()
    gt_mask_list = list()
    gt_list = list()
    heatmaps = None

    loss_fn.train()
    for (x, _, _) in train_loader:
        optimizer.zero_grad()
        p = model(x.to(device))
        loss, _ = loss_fn(p)
        loss.backward()
        optimizer.step()
    loss_fn.eval()
```
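Whether the loop above updates `model` may depend on how `DSVDD` stores it, which isn't shown here. `model.eval()` alone does not freeze weights: it only switches dropout and batch-norm behaviour, and `requires_grad` stays `True`. Also, if a wrapper assigns an `nn.Module` as an attribute, that module's parameters are registered, so `loss_fn.parameters()` would include them. The minimal sketch below checks both points with a hypothetical `ToyDSVDD` (an assumption, not the repo's class).

```python
import torch
import torch.nn as nn
import torch.optim as optim


class ToyDSVDD(nn.Module):
    """Hypothetical stand-in for DSVDD that keeps a reference to the backbone."""

    def __init__(self, backbone: nn.Module, feat_dim: int):
        super().__init__()
        self.backbone = backbone  # assigning an nn.Module registers it as a submodule
        self.center = nn.Parameter(torch.zeros(feat_dim))  # learnable center

    def forward(self, feats: torch.Tensor) -> torch.Tensor:
        # Mean squared distance of the features to the center.
        return ((feats - self.center) ** 2).sum(dim=1).mean()


torch.manual_seed(0)
model = nn.Linear(4, 3)
model.eval()  # changes dropout/batch-norm behaviour only; gradients still flow
loss_fn = ToyDSVDD(model, feat_dim=3)

backbone_ids = {id(p) for p in model.parameters()}
optimized_ids = {id(p) for p in loss_fn.parameters()}
print("backbone params in optimizer:", backbone_ids <= optimized_ids)

optimizer = optim.AdamW([{'params': loss_fn.parameters()}], lr=1e-3)
before = model.weight.detach().clone()
loss = loss_fn(model(torch.randn(8, 4)))
loss.backward()
optimizer.step()
print("backbone weight changed:", not torch.equal(before, model.weight))
```

Running the same `id` check on the real `model` and `loss_fn` would answer the question directly. If the backbone parameters are not in `loss_fn.parameters()` (or `DSVDD` sets `requires_grad=False` on them), the saved `model_state_dict` would match the pre-trained weights. In that case, anything `DSVDD` learns would only live in `loss_fn.state_dict()`.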