tientrandinh / Revisiting-Reverse-Distillation

(CVPR 2023) Revisiting Reverse Distillation for Anomaly Detection
MIT License

running so slow #18

Open sevenactors opened 3 months ago

sevenactors commented 3 months ago

I want to test the code on my device, but it runs very slowly. GPU memory is in use, yet CPU usage is also high. I checked the model and the input tensors, and they are all on the CUDA device, not the CPU, so I am confused about why it uses the CPU for computation and runs so slowly.
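A minimal sketch for double-checking device placement, assuming the standard PyTorch API. It walks torch.nn.Module.named_parameters() and named_buffers() (buffers such as BatchNorm running statistics are easy to overlook), plus any extra tensors, and lists everything whose device.type differs from the expected one. find_misplaced is a hypothetical helper, not part of the repository. An empty result only rules out misplaced parameters, buffers and inputs; it does not by itself show where the CPU time is going.

```python
def find_misplaced(model, *tensors, expected="cuda"):
    """List parameters, buffers and extra tensors whose device type != expected.

    model: a torch.nn.Module (only named_parameters() and named_buffers()
    are used); tensors: any extra tensors to check, e.g. the input batch.
    """
    misplaced = []
    for name, p in model.named_parameters():
        if p.device.type != expected:
            misplaced.append(f"param {name}: {p.device}")
    # Buffers (e.g. BatchNorm running_mean/running_var) are not parameters.
    for name, b in model.named_buffers():
        if b.device.type != expected:
            misplaced.append(f"buffer {name}: {b.device}")
    for i, t in enumerate(tensors):
        if t.device.type != expected:
            misplaced.append(f"input {i}: {t.device}")
    return misplaced
```

For example, find_misplaced(encoder, img, img_noise) should return an empty list when the module and both batches are on the GPU; here encoder, img and img_noise stand for whichever model and inputs you are checking.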

iulianzorila commented 3 months ago

I experienced the same issue at training time, even though I am using a GPU. After debugging the code, I noticed that the slowest step is computing proj_loss, so the slowdown might be due to the Sinkhorn computation in utils_train.py.

```python
for i, (img, img_noise, _) in enumerate(train_dataloader):
    img = img.to(device)
    img_noise = img_noise.to(device)
    # Encoder features for the clean and the noisy image
    inputs = encoder(img)
    inputs_noise = encoder(img_noise)

    # Projection layer applied to the clean and the noisy features
    (feature_space_noise, feature_space) = proj_layer(inputs, features_noise=inputs_noise)

    #####
    L_proj = proj_loss(inputs_noise, feature_space_noise, feature_space)  ### (Sinkhorn computation) ###
    #####

    outputs = decoder(bn(feature_space))  # bn(inputs))
    L_distill = loss_fucntion(inputs, outputs)
    loss = L_distill + pars.weight_proj * L_proj
    loss.backward()
    # Gradient accumulation: optimizers step only every `accumulation_steps` batches
    if (i + 1) % accumulation_steps == 0:
        optimizer_proj.step()
        optimizer_distill.step()
        # Clear gradients
        optimizer_proj.zero_grad()
        optimizer_distill.zero_grad()

    # .item() returns a Python number, so it waits for queued GPU work every iteration
    total_loss_running += loss.detach().cpu().item()
    loss_proj_running += L_proj.detach().cpu().item()
    loss_distill_running += L_distill.detach().cpu().item()
```
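To confirm which step dominates, as described above for proj_loss, each call can be wrapped in a timer. CUDA kernels run asynchronously, so without synchronization the elapsed time tends to show up in whichever later call waits for the GPU (for example the .item() calls at the end of the loop). The sketch below is a hypothetical StepTimer helper, not part of the repository; passing sync=torch.cuda.synchronize makes it wait for queued GPU work before each clock read.

```python
import time
from collections import defaultdict
from contextlib import contextmanager


class StepTimer:
    """Accumulates wall-clock time per named step (hypothetical helper).

    sync: optional zero-argument callable invoked before each clock read,
    e.g. torch.cuda.synchronize, so that asynchronous GPU work launched in
    a step is charged to that step instead of to a later one.
    """

    def __init__(self, sync=None):
        self.sync = sync
        self.totals = defaultdict(float)  # seconds per step name
        self.counts = defaultdict(int)    # number of calls per step name

    @contextmanager
    def step(self, name):
        if self.sync is not None:
            self.sync()
        start = time.perf_counter()
        try:
            yield
        finally:
            if self.sync is not None:
                self.sync()
            self.totals[name] += time.perf_counter() - start
            self.counts[name] += 1

    def report(self):
        """(name, total seconds, mean milliseconds per call), slowest first."""
        rows = sorted(self.totals.items(), key=lambda kv: kv[1], reverse=True)
        return [(name, total, 1000 * total / self.counts[name]) for name, total in rows]


# Usage in the training loop (assumes `import torch`):
#   timer = StepTimer(sync=torch.cuda.synchronize)
#   with timer.step("proj_loss"):
#       L_proj = proj_loss(inputs_noise, feature_space_noise, feature_space)
#   ...after some iterations:
#   for name, total_s, mean_ms in timer.report():
#       print(f"{name}: {total_s:.1f} s total, {mean_ms:.1f} ms per call")
```

Synchronizing around every step adds some overhead of its own, so this is meant for a short diagnostic run rather than for normal training.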

The proj_loss computation depends on geomloss.SamplesLoss. I have tried different backends, but the only one that seems to work for me is backend='tensorized'.
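For intuition about why this step can be expensive, here is a toy log-domain Sinkhorn loop for entropic optimal transport between two uniformly weighted point clouds, in pure Python. It is not geomloss's implementation: SamplesLoss additionally anneals the blur (ε-scaling) and by default returns a debiased Sinkhorn divergence, whereas sinkhorn_cost, eps and n_iters here are hypothetical names for a plain entropic transport cost. What it shares with a dense backend such as backend='tensorized' is that the full N × M cost matrix is materialized and every iteration passes over all N·M entries, so the work grows with the product of the two cloud sizes times the number of iterations. (The 'online' and 'multiscale' backends instead rely on KeOps, which may be related to why they did not work here, though the thread does not confirm this.)

```python
import math


def logsumexp(values):
    """Numerically stable log(sum(exp(v) for v in values))."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def sinkhorn_cost(x, y, eps=0.05, n_iters=200):
    """Entropic OT cost between uniformly weighted point clouds x and y.

    x: list of N points, y: list of M points (each point a list of floats).
    Toy dense implementation: the whole N x M cost matrix is built up front,
    and each Sinkhorn iteration reads every one of its N*M entries.
    """
    n, m = len(x), len(y)
    # Cost C[i][j] = |x_i - y_j|^2 / 2 (half squared Euclidean distance).
    C = [[sum((a - b) ** 2 for a, b in zip(xi, yj)) / 2 for yj in y] for xi in x]
    log_a, log_b = -math.log(n), -math.log(m)  # uniform weights 1/N and 1/M
    f, g = [0.0] * n, [0.0] * m  # dual potentials
    for _ in range(n_iters):
        # Update f so that every row of the plan sums to 1/N ...
        f = [-eps * logsumexp([log_b + (g[j] - C[i][j]) / eps for j in range(m)])
             for i in range(n)]
        # ... then g so that every column sums to 1/M.
        g = [-eps * logsumexp([log_a + (f[i] - C[i][j]) / eps for i in range(n)])
             for j in range(m)]
    # Plan P_ij = a_i * b_j * exp((f_i + g_j - C_ij) / eps); return sum_ij P_ij * C_ij.
    return sum(
        math.exp(log_a + log_b + (f[i] + g[j] - C[i][j]) / eps) * C[i][j]
        for i in range(n)
        for j in range(m)
    )
```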