zwx8981 / LIQE

[CVPR2023] Blind Image Quality Assessment via Vision-Language Correspondence: A Multitask Learning Perspective
MIT License

train issues #22

Open ctxya1207 opened 1 month ago

ctxya1207 commented 1 month ago

How can I train LIQE on an AIGC dataset?

zwx8981 commented 1 month ago

I think there are only two things to do: (1) write a data loader for the target AIGC dataset; (2) modify the training code to omit the two auxiliary tasks (scene classification and distortion type identification) if you only want to train with quality labels.
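As an illustration of step (1), here is a minimal sketch of a map-style PyTorch dataset. It is not taken from the LIQE repository: the class name `AIGCQualityDataset`, the CSV layout (`image,mos` header), the returned dictionary keys, and the injected `load_image` callable are all assumptions made for the example.

```python
import csv
from typing import Callable, List, Tuple

import torch
from torch.utils.data import Dataset


class AIGCQualityDataset(Dataset):
    """Hypothetical dataset: one (image path, quality score) pair per CSV row."""

    def __init__(self, csv_path: str, load_image: Callable[[str], torch.Tensor]):
        # Assumed CSV layout: a header row "image,mos", then one row per image.
        with open(csv_path, newline="") as f:
            rows = list(csv.DictReader(f))
        self.samples: List[Tuple[str, float]] = [
            (row["image"], float(row["mos"])) for row in rows
        ]
        # Caller-supplied function that reads a file and returns a tensor,
        # e.g. PIL loading followed by whatever preprocessing the model needs.
        self.load_image = load_image

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, index: int):
        path, mos = self.samples[index]
        return {"image": self.load_image(path), "mos": torch.tensor(mos)}
```

Such a dataset can be wrapped in `torch.utils.data.DataLoader(dataset, batch_size=..., shuffle=True)`; the keys expected by the actual training script may differ and should be checked against the repository.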

ctxya1207 commented 1 month ago

Thank you. One more question: which function is the fidelity loss for the quality score?

zwx8981 commented 1 month ago

@ctxya1207 I have updated the code by adding a script to enable single-database training of LIQE with quality labels only. See Readme:

python train_liqe_single.py

ctxya1207 commented 1 month ago

def loss_m(y_pred, y):
    """prediction monotonicity related loss"""
    assert y_pred.size(0) > 1
    preds = y_pred - (y_pred + 10).t()
    gts = y.t() - y
    triu_indices = torch.triu_indices(y_pred.size(0), y_pred.size(0), offset=1)
    preds = preds[triu_indices[0], triu_indices[1]]
    gts = gts[triu_indices[0], triu_indices[1]]
    return torch.sum(F.relu(preds * torch.sign(gts))) / preds.size(0)

Does this loss function compute the fidelity loss?

ctxya1207 commented 1 month ago

What I actually want is to use the fidelity loss when predicting the consistency score in AIGC image quality assessment, but I don't know how to write this function.

zwx8981 commented 1 month ago

We have several implementation variants of the fidelity loss. By default, our original implementation uses loss_m4, which takes the predicted quality, the number of images sampled from each dataset, and the ground-truth quality as input, computes the fidelity loss on each dataset, and averages the per-dataset losses into the final loss value.
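The per-dataset averaging described above can be sketched as follows. This is not the repository's `loss_m4`; it is a reconstruction from the description, and the function names, the `counts` argument, and the value `eps=1e-8` are assumptions made for the example.

```python
import math
from typing import Sequence

import torch


def fidelity_loss(y_pred: torch.Tensor, y: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Pairwise fidelity loss over all pairs i < j within one group of images."""
    y_pred = y_pred.reshape(-1, 1)  # accept shape (N,) or (N, 1)
    y = y.reshape(-1, 1)
    n = y_pred.size(0)
    assert n > 1, "need at least two images to form a pair"
    pred_diff = y_pred - y_pred.t()  # (N, N) predicted score differences
    gt_diff = y - y.t()              # (N, N) ground-truth score differences
    i, j = torch.triu_indices(n, n, offset=1)  # each unordered pair once
    pred_diff, gt_diff = pred_diff[i, j], gt_diff[i, j]
    g = 0.5 * (torch.sign(gt_diff) + 1)                   # target preference in {0, 0.5, 1}
    p = 0.5 * (1 + torch.erf(pred_diff / math.sqrt(2)))   # predicted preference probability
    return torch.mean(1 - (torch.sqrt(p * g + eps) + torch.sqrt((1 - p) * (1 - g) + eps)))


def fidelity_loss_per_dataset(y_pred: torch.Tensor, y: torch.Tensor,
                              counts: Sequence[int]) -> torch.Tensor:
    """Average the fidelity loss computed separately on each dataset's images.

    counts[k] is the number of consecutive entries that come from dataset k.
    """
    pred_chunks = torch.split(y_pred.reshape(-1), list(counts))
    gt_chunks = torch.split(y.reshape(-1), list(counts))
    losses = [fidelity_loss(p, g) for p, g in zip(pred_chunks, gt_chunks)]
    return torch.stack(losses).mean()
```

Pairs are formed only within a dataset, so scores from databases with different rating scales are never compared directly.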

zwx8981 commented 1 month ago

If you only want the fidelity loss, loss_m3 would be fine. loss_m is an implementation of margin ranking loss, not fidelity loss.
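For reference, the quantity computed by the loss_m3 code quoted later in this thread can be written out as below. The notation is introduced here for illustration: \hat{q} are predicted scores, q are ground-truth scores, \Phi is the standard normal CDF, and \epsilon is the small constant named `esp` in the code.

```latex
g_{ij} = \tfrac{1}{2}\bigl(\operatorname{sign}(q_i - q_j) + 1\bigr), \qquad
p_{ij} = \Phi(\hat{q}_i - \hat{q}_j)
       = \tfrac{1}{2}\Bigl(1 + \operatorname{erf}\Bigl(\tfrac{\hat{q}_i - \hat{q}_j}{\sqrt{2}}\Bigr)\Bigr)

\ell = \frac{1}{|P|} \sum_{(i,j) \in P}
       \Bigl[\, 1 - \sqrt{p_{ij}\, g_{ij} + \epsilon}
                 - \sqrt{(1 - p_{ij})(1 - g_{ij}) + \epsilon} \,\Bigr],
\qquad P = \{(i,j) : i < j\}
```

By contrast, loss_m applies a ReLU to signed score differences and involves no preference probability, which is why it is a ranking loss rather than a fidelity loss.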

ctxya1207 commented 1 month ago

Thank you very much. Could you share a contact method so that we can communicate more easily?

zwx8981 commented 1 month ago

Feel free to contact me by e-mail: zwx8981@sjtu.edu.cn

ctxya1207 commented 1 month ago

def loss_m3(y_pred, y):
    """prediction monotonicity related loss"""
    assert y_pred.size(0) > 1
    y_pred = y_pred.unsqueeze(1)
    y = y.unsqueeze(1)
    preds = y_pred - y_pred.t()
    gts = y - y.t()

    # signed = torch.sign(gts)

    # Keep each unordered pair (i < j) once.
    triu_indices = torch.triu_indices(y_pred.size(0), y_pred.size(0), offset=1)
    preds = preds[triu_indices[0], triu_indices[1]]
    gts = gts[triu_indices[0], triu_indices[1]]
    # Ground-truth preference: 1, 0.5, or 0.
    g = 0.5 * (torch.sign(gts) + 1)

    # Predicted preference probability via the standard normal CDF.
    constant = torch.sqrt(torch.Tensor([2.])).to(preds.device)
    p = 0.5 * (1 + torch.erf(preds / constant))

    g = g.view(-1, 1)
    p = p.view(-1, 1)

    # esp is a small constant defined outside this function.
    loss = torch.mean((1 - (torch.sqrt(p * g + esp) + torch.sqrt((1 - p) * (1 - g) + esp))))

    return loss

In this function, if the size of y is (batch_size, 1), should unsqueeze(1) be removed?

zwx8981 commented 1 month ago

Yes
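To see why: if y already has shape (batch_size, 1), another `unsqueeze(1)` makes it (batch_size, 1, 1), and `.t()` only accepts tensors with at most two dimensions. One way to sidestep the question, shown here as a sketch rather than as the repository's code (the helper name `pairwise_differences` is hypothetical), is to normalize the shape with `reshape(-1, 1)` so that both layouts work:

```python
import torch


def pairwise_differences(x: torch.Tensor) -> torch.Tensor:
    """Return the N x N matrix d[i, j] = x[i] - x[j] for x shaped (N,) or (N, 1)."""
    x = x.reshape(-1, 1)  # (N,) becomes (N, 1); (N, 1) is left unchanged
    return x - x.t()      # (N, 1) - (1, N) broadcasts to (N, N)


y_flat = torch.tensor([1.0, 2.0, 4.0])  # shape (3,)
y_col = y_flat.unsqueeze(1)             # shape (3, 1)
# Both layouts give the same pairwise difference matrix.
assert torch.equal(pairwise_differences(y_flat), pairwise_differences(y_col))
```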

ctxya1207 commented 1 month ago

In train_liqe_single.py there is the line total_loss = total_loss + 0.1*refine_loss. What does refine_loss mean, and why is its weight 0.1?

zwx8981 commented 1 month ago

Sorry, that was leftover code that had not been cleaned up. I've fixed it; please try again.

ctxya1207 commented 1 month ago

running_loss = beta * running_loss + (1 - beta) * total_loss.data.item()

Why is beta set to 0.9?

zwx8981 commented 1 month ago

This is only a momentum factor for computing the moving average of the loss; it does not affect training.
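A small plain-Python sketch of that moving average, to show that beta only smooths the reported number. The function name `smoothed_losses` and the starting value of 0.0 are assumptions for the example; the training script may initialize the running value differently.

```python
def smoothed_losses(losses, beta=0.9):
    """Exponential moving average of a sequence of per-step loss values."""
    running = 0.0  # assumed starting value
    history = []
    for loss in losses:
        # Same update as in the training script: keep beta of the old value,
        # blend in (1 - beta) of the newest loss.
        running = beta * running + (1 - beta) * loss
        history.append(running)
    return history


# A larger beta gives a smoother, slower-moving curve; the raw losses
# (and therefore the gradients) are untouched either way.
print(smoothed_losses([1.0, 1.0, 1.0], beta=0.9))
```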

ctxya1207 commented 1 month ago

num_steps_per_epoch = 200: is this variable equivalent to batch_size?
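The thread does not answer this question. Judging by the name alone (an assumption, since the relevant part of the script is not quoted here), the variable counts optimization steps, that is, mini-batches processed per epoch, which is a different quantity from the number of images in each mini-batch. A generic sketch of the distinction, with a hypothetical `run_epoch` helper:

```python
def run_epoch(batches, num_steps_per_epoch):
    """Consume at most num_steps_per_epoch mini-batches; return images seen."""
    images_seen = 0
    for step, batch in enumerate(batches):
        if step >= num_steps_per_epoch:
            break
        # A real training loop would run forward/backward on `batch` here.
        images_seen += len(batch)
    return images_seen


batch_size = 4
batches = [[0] * batch_size for _ in range(10)]  # ten dummy mini-batches
# Three steps of four images each.
print(run_epoch(batches, num_steps_per_epoch=3))
```

Under this reading, an epoch with 200 steps sees 200 times batch_size images.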

ctxya1207 commented 1 month ago

        print('...............current average best...............')
        print('best average epoch:{}'.format(best_epoch['avg']))
        print('best average result:{}'.format(best_result['avg']))
        for dataset in srcc_dict.keys():
            print_text = dataset + ':' + 'scene:{}, distortion:{}, srcc:{}'.format(
                scene_dict[dataset], type_dict[dataset], srcc_dict[dataset])
            print(print_text)

        print('...............current quality best...............')
        print('best quality epoch:{}'.format(best_epoch['quality']))
        print('best quality result:{}'.format(best_result['quality']))
        for dataset in srcc_dict1.keys():
            print_text = dataset + ':' + 'scene:{}, distortion:{}, srcc:{}'.format(
                scene_dict1[dataset], type_dict1[dataset], srcc_dict1[dataset])
            print(print_text)

What is the difference between "avg best" and "quality best"?
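The thread ends without an answer. One plausible reading of the printed keys, stated here purely as an assumption and not confirmed by the author, is that two separate "best" epochs are tracked: one chosen by an average over the three reported metrics (scene accuracy, distortion accuracy, SRCC) and one chosen by SRCC alone. A sketch of such bookkeeping, with a hypothetical `update_best` helper:

```python
def update_best(best_result, best_epoch, epoch, scene_acc, distortion_acc, srcc):
    """Track two independent best epochs under two assumed criteria."""
    # Assumed meaning of 'avg': mean of the three task metrics.
    avg = (scene_acc + distortion_acc + srcc) / 3
    if avg > best_result["avg"]:
        best_result["avg"], best_epoch["avg"] = avg, epoch
    # Assumed meaning of 'quality': the quality metric (SRCC) on its own.
    if srcc > best_result["quality"]:
        best_result["quality"], best_epoch["quality"] = srcc, epoch


best_result = {"avg": 0.0, "quality": 0.0}
best_epoch = {"avg": -1, "quality": -1}
update_best(best_result, best_epoch, 0, scene_acc=0.9, distortion_acc=0.9, srcc=0.6)
update_best(best_result, best_epoch, 1, scene_acc=0.5, distortion_acc=0.5, srcc=0.8)
print(best_epoch)
```

Under this reading the two criteria can pick different epochs, as in the example, where the auxiliary tasks peak earlier than the quality metric.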