sungwool / CFA_for_anomaly_localization


model save #21

Open linzengmin opened 6 months ago

linzengmin commented 6 months ago

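I'm confused about why the model is saved here. From the code, it seems that model itself is never trained: it stays in eval() mode and the optimizer only receives loss_fn.parameters(), so the saved model.state_dict() is just the pre-trained weights.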

        torch.save(
            {
                "epoch": epoch,
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
            },
            os.path.join(checkpoint_dir, "best.pt"),
        )
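In the training code, the optimizer only updates loss_fn (the DSVDD module), so if the goal is to keep what was actually trained, one option is to store loss_fn.state_dict() in the same checkpoint. Below is a minimal sketch under stated assumptions: nn.Linear layers stand in for the real backbone and DSVDD module, the trained state is assumed to live in loss_fn's parameters/buffers, and the key "loss_fn_state_dict" is a hypothetical name, not something the repository is known to use.

```python
import os
import tempfile

import torch
import torch.nn as nn
import torch.optim as optim

# Toy stand-ins (assumptions): `model` for the pre-trained backbone,
# `loss_fn` for the trainable DSVDD module.
model = nn.Linear(8, 4)
loss_fn = nn.Linear(4, 1)
optimizer = optim.AdamW(loss_fn.parameters(), lr=1e-3)
epoch = 0

checkpoint_dir = tempfile.mkdtemp()  # temporary directory just for this demo
path = os.path.join(checkpoint_dir, "best.pt")

torch.save(
    {
        "epoch": epoch,
        "model_state_dict": model.state_dict(),
        # Hypothetical extra key holding the weights the optimizer actually updates.
        "loss_fn_state_dict": loss_fn.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
    },
    path,
)

# Restore into fresh modules with the same architecture.
ckpt = torch.load(path, map_location="cpu")
restored_model = nn.Linear(8, 4)
restored_loss_fn = nn.Linear(4, 1)
restored_model.load_state_dict(ckpt["model_state_dict"])
restored_loss_fn.load_state_dict(ckpt["loss_fn_state_dict"])
```

With a real DSVDD module, the restored loss_fn would need to be constructed the same way as during training before load_state_dict is called.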

Training code:

        model = model.to(device)
        model.eval()  # sets eval mode (BatchNorm/Dropout behavior); does not by itself freeze weights

        loss_fn = DSVDD(model, train_loader, args.cnn, args.gamma_c, args.gamma_d, device)
        loss_fn = loss_fn.to(device)

        epochs = 30
        # Only loss_fn's parameters are passed to the optimizer; model's are not.
        params = [{"params": loss_fn.parameters()}]
        optimizer = optim.AdamW(
            params=params,
            lr=1e-3,
            weight_decay=5e-4,
            amsgrad=True,
        )
        best_pxl_pro = -1
        for epoch in tqdm(range(epochs), '%s -->'%(class_name)):
            r'TEST PHASE'

            test_imgs = list()
            gt_mask_list = list()
            gt_list = list()
            heatmaps = None

            loss_fn.train()
            for (x, _, _) in train_loader:
                optimizer.zero_grad()
                p = model(x.to(device))  # backbone features; model's weights are not stepped by this optimizer

                loss, _ = loss_fn(p)
                loss.backward()
                optimizer.step()
            loss_fn.eval()
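The observation that model is not trained can be checked with a small experiment: snapshot the weights, run a few steps with an optimizer built only from loss_fn.parameters(), and compare. This is a minimal sketch with toy nn.Linear stand-ins and a placeholder loss (assumptions, not the repository's DSVDD loss); only the optimizer setup mirrors the training code.

```python
import copy

import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)

# Toy stand-ins (assumptions, not the repo's classes): `model` plays the
# pre-trained backbone, `loss_fn` plays the trainable DSVDD module.
model = nn.Linear(8, 4)
loss_fn = nn.Linear(4, 1)

model.eval()  # same call as in the issue; it does not freeze weights by itself
optimizer = optim.AdamW(
    [{"params": loss_fn.parameters()}],  # only loss_fn's params, as in the issue
    lr=1e-3,
    weight_decay=5e-4,
    amsgrad=True,
)

# Snapshot both modules' weights before training.
model_before = copy.deepcopy(model.state_dict())
loss_fn_before = copy.deepcopy(loss_fn.state_dict())

x = torch.randn(16, 8)
for _ in range(5):
    optimizer.zero_grad()
    loss = loss_fn(model(x)).pow(2).mean()  # placeholder loss, not DSVDD's
    loss.backward()   # gradients are computed for model's params too...
    optimizer.step()  # ...but only loss_fn's params are updated


def unchanged(before, after):
    """True if every tensor in the two state dicts is identical."""
    return all(torch.equal(before[k], after[k]) for k in before)


model_unchanged = unchanged(model_before, model.state_dict())
loss_fn_unchanged = unchanged(loss_fn_before, loss_fn.state_dict())
print(model_unchanged, loss_fn_unchanged)  # prints: True False
```

Because model's parameters never enter any optimizer param group, their values stay identical to the pre-trained ones, which is consistent with saving model.state_dict() producing the same weights the training started from.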