cvlab-stonybrook / Scanpath_Prediction

Predicting Goal-directed Human Attention Using Inverse Reinforcement Learning (CVPR2020)

The results cannot be reproduced by me #16

Closed Mikiloo closed 2 years ago

Mikiloo commented 3 years ago

I have trained the model with the hparams provided, but the learning rate and some other parameters do not match the supplementary material of the paper. That may be why I cannot reproduce the model you provided. Could you update the hparams? Thanks.

ouyangzhibo commented 3 years ago

I have trained the model with the hparams provided, but the learning rate and some other parameters do not match the supplementary material of the paper. That may be why I cannot reproduce the model you provided. Could you update the hparams? Thanks.

Hi @Mikiloo, can you please provide more info for me to diagnose the problem? Does the return of PPO converge at all?

Mikiloo commented 3 years ago

Here is the situation: with the hyperparameters in the code, the discriminator loss decreases slowly, while the generator loss keeps oscillating, and the return of PPO oscillates as well.

ouyangzhibo commented 3 years ago

Here is the situation: with the hyperparameters in the code, the discriminator loss decreases slowly, while the generator loss keeps oscillating, and the return of PPO oscillates as well.

I re-ran the code but did not observe the problem you described. However, I did update a couple of parameters (disabled the lr scheduler and lowered the entropy coefficient from 0.1 to 0.01); please check whether that makes any difference.
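For reference, the two changes could be reflected in the hparams file roughly as below (a minimal sketch; the key names are hypothetical and may not match hparams/coco_search18.json, and disabling the scheduler may instead mean skipping the lr_scheduler.step() call in the training loop):

import json

# Hypothetical sketch of the two hyper-parameter changes; the key names are
# assumed, not taken from the actual config.
with open('hparams/coco_search18.json') as f:
    hp = json.load(f)
hp['Train']['entropy_coef'] = 0.01       # lowered from 0.1
hp['Train']['use_lr_scheduler'] = False  # lr scheduler disabled
with open('hparams/coco_search18.json', 'w') as f:
    json.dump(hp, f, indent=2)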

Normally, the return of PPO should look like the attached plot.

It is normal for the generator loss to oscillate as the discriminator changes. The discriminator loss (real and fake) should look like the attached plot.

Hope this helps!

Mikiloo commented 3 years ago

Here is the situation: with the hyperparameters in the code, the discriminator loss decreases slowly, while the generator loss keeps oscillating, and the return of PPO oscillates as well.

I re-ran the code but did not observe the problem you described. However, I did update a couple of parameters (disabled the lr scheduler and lowered the entropy coefficient from 0.1 to 0.01); please check whether that makes any difference.

Normally, the return of PPO should look like the attached plot.

It is normal for the generator loss to oscillate as the discriminator changes. The discriminator loss (real and fake) should look like the attached plot.

Hope this helps!

Hi, I tried again with the new hparams, but I cannot get the results of the model you provided in test.py; I still get lower MultiMatch scores of 0.805, 0.558, 0.765, 0.585, the same as before.

ouyangzhibo commented 3 years ago

Hi, I tried again with the new hparams, but I cannot get the results of the model you provided in test.py; I still get lower MultiMatch scores of 0.805, 0.558, 0.765, 0.585, the same as before.

Not sure why you are getting these numbers, but can you first confirm whether the training loss behaved normally? If not, the model is not converging, and that could be why you are getting these numbers. Also, be sure to check the other metrics as well, e.g., probability mismatch and sequence score.
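For reference, the probability mismatch can be computed from the target-fixation CDFs as in the short sketch below (it uses the same utilities as the test scripts later in this thread and assumes predictions, human_scanpaths_valid, bbox_annos and hparams are already loaded in that format):

import numpy as np
from irl_dcb.utils import compute_search_cdf

# Probability mismatch: sum of absolute differences between the human and
# model target-fixation CDFs.
human_mean_cdf, _ = compute_search_cdf(human_scanpaths_valid, bbox_annos,
                                       hparams.Data.max_traj_length)
model_mean_cdf, _ = compute_search_cdf(predictions, bbox_annos,
                                       hparams.Data.max_traj_length)
prob_mismatch = np.sum(np.abs(human_mean_cdf - model_mean_cdf))
print('Probability Mismatch:', prob_mismatch)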

Doch88 commented 3 years ago

Here is the situation: with the hyperparameters in the code, the discriminator loss decreases slowly, while the generator loss keeps oscillating, and the return of PPO oscillates as well.

Could you write this message in English, please?

Mikiloo commented 3 years ago

Here is the situation: with the hyperparameters in the code, the discriminator loss decreases slowly, while the generator loss keeps oscillating, and the return of PPO oscillates as well.

Could you write this message in English, please?

The discriminator loss continues to decline, but the generator loss and the return of PPO oscillate.

ouyangzhibo commented 3 years ago

Hi, I tried again with the new hparams, but I cannot get the results of the model you provided in test.py; I still get lower MultiMatch scores of 0.805, 0.558, 0.765, 0.585, the same as before.

Not sure why you are getting these numbers, but can you first confirm whether the training loss behaved normally? If not, the model is not converging, and that could be why you are getting these numbers. Also, be sure to check the other metrics as well, e.g., probability mismatch and sequence score.

Since there is no further activity, I am closing this issue. Please feel free to reopen it if needed.

StoyanVenDimitrov commented 2 years ago

Hi,

I trained with the hparams you've submitted on master and left out the lr_scheduler.step(). I get the overall training behaviour right, but not with the values from your plots above. Also, with the model trained this way, I get MultiMatch scores of [0.51449315 0.61993859 0.23987709 0.32104297]. For testing I took your script from issue #12 and just added metrics.compute_mm(human_scanpaths_valid, predictions, hparams.Data.im_w, hparams.Data.im_h) to get the scores. Have you maybe changed the hyperparameters to get the results you show?

I also wondered why, in metrics.compute_mm(), you take the mean of the MultiMatch scores over all human scanpaths instead of the highest one?

(Three screenshots attached.)

ouyangzhibo commented 2 years ago

@StoyanVenDimitrov I am reopening this issue for further discussion. To help diagnose the problem, can you share more training and validation stats? For example, the sequence score and MultiMatch scores on the validation set. Also, as a sanity check, you can try the pre-trained model and see if it works normally on the test set.

StoyanVenDimitrov commented 2 years ago

Hi, I noticed that I get a very low target fixation probability, [0. 0. 0.00035389 0.00035389 0.00035389 0.00035389 0.00035389], so basically none of the annotated scanpaths in the test set reach the target bounding box. The same target bounding box data works fine with the train and valid scanpaths. Could you please verify that these are the right test annotations: https://drive.google.com/file/d/1nSORDDMiz6uv5mhgvMfBVpMyJUSnyYFl/view?usp=sharing ?

ouyangzhibo commented 2 years ago

Not sure if you've noticed this in the README.md:

Note that in this paper we rescaled the images to 512x320 as well as the fixation locations. The original COCO-Search18 dataset was collected on a 1680x1050 display. The computed belief maps and rescaled fixations used in this paper can be found at this link.

You need to rescale the raw annotations (both the bounding boxes and fixations) in the test set to 512x320.
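For example, the rescaling could look roughly like this (a sketch; it assumes fixations stored in 'X'/'Y' lists and a hypothetical raw_bbox_annos dict with boxes as [x, y, w, h] in the original 1680x1050 coordinates):

# Rescale raw 1680x1050 annotations to the 512x320 resolution used in the paper.
SX, SY = 512 / 1680, 320 / 1050

for scanpath in human_scanpaths_test:
    scanpath['X'] = [x * SX for x in scanpath['X']]
    scanpath['Y'] = [y * SY for y in scanpath['Y']]

# The raw bounding boxes need the same scaling (box format assumed to be [x, y, w, h]).
for key, (x, y, w, h) in raw_bbox_annos.items():
    raw_bbox_annos[key] = (x * SX, y * SY, w * SX, h * SY)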

StoyanVenDimitrov commented 2 years ago

Thank you, I definitely skipped this part. Now I have [0.00793765 0.42771553 0.74378315 0.85568454 0.88808041 0.89897851 0.90245387], which is better; the final MultiMatch results are [0.88061413 0.68688441 0.85857821 0.88670973], which is slightly worse than what you reported. But still, 10% of the test scanpaths don't reach the target. Is this normal, or do you get different results? This is my test script, largely repeating your script from #12:

"""Test script.
Usage:
  tester.py <hparams> <dataset_root> <checkpoint_dir> [--cuda=<id>]
  tester.py -h | --help

Options:
  -h --help     Show this screen.
  --cuda=<id>   id of the cuda device [default: 0].
"""
import os, json
import torch
import numpy as np
from tqdm import tqdm
from docopt import docopt
from os.path import join
from dataset import process_data
from irl_dcb.config import JsonConfig
from torch.utils.data import DataLoader
from irl_dcb.models import LHF_Policy_Cond_Small
from irl_dcb.environment import IRL_Env4LHF
from irl_dcb import metrics
from irl_dcb import utils
torch.manual_seed(42619)
np.random.seed(42619)

def gen_scanpaths(generator,
                  env_test,
                  test_img_loader,
                  hparams,
                  num_sample=10):
    patch_num = hparams.Data.patch_num
    max_traj_len = hparams.Data.max_traj_length
    all_actions = []
    for i_sample in range(num_sample):
        progress = tqdm(test_img_loader,
                        desc='trial ({}/{})'.format(i_sample + 1, num_sample))
        for i_batch, batch in enumerate(progress):
            env_test.set_data(batch)
            img_names_batch = batch['img_name']
            cat_names_batch = batch['cat_name']
            with torch.no_grad():
                env_test.reset()
                trajs = utils.collect_trajs(env_test,
                                            generator,
                                            patch_num,
                                            max_traj_len,
                                            is_eval=True,
                                            sample_action=True)
                all_actions.extend([(cat_names_batch[i], img_names_batch[i],
                                     'present', trajs['actions'][:, i])
                                    for i in range(env_test.batch_size)])

    scanpaths = utils.actions2scanpaths(all_actions, patch_num, hparams.Data.im_w, hparams.Data.im_h)
    utils.cutFixOnTarget(scanpaths, bbox_annos)

    return scanpaths

if __name__ == '__main__':
    args = docopt(__doc__)
    device = torch.device('cuda:{}'.format(args['--cuda']))
    hparams = args["<hparams>"]
    dataset_root = args["<dataset_root>"]
    checkpoint = args["<checkpoint_dir>"]
    hparams = JsonConfig(hparams)
    bbox_annos = np.load(join(dataset_root, 'bbox_annos.npy'),
                         allow_pickle=True).item()
    with open(join(dataset_root,
                   'human_scanpaths_TP_trainval_train.json')) as json_file:
        human_scanpaths_train = json.load(json_file)

    # ! coco test data instead of validation set
    with open(join(dataset_root,
                   'coco_test.json')) as json_file:
        human_scanpaths_test = json.load(json_file)

    human_scanpaths_test = list(
            filter(lambda x: x['correct'] == 1, human_scanpaths_test))

    for scanpath in human_scanpaths_test:
        scanpath['X'] = [x * 512/1680 for x in scanpath['X']]
        scanpath['Y'] = [x * 320/1050 for x in scanpath['Y']]

    # dir of pre-computed beliefs
    DCB_dir_HR = join(dataset_root, 'DCBs/HR/')
    DCB_dir_LR = join(dataset_root, 'DCBs/LR/')
    data_name = '{}x{}'.format(hparams.Data.im_w, hparams.Data.im_h)

    # process fixation data
    dataset = process_data(human_scanpaths_train, human_scanpaths_test,
                           DCB_dir_HR,
                           DCB_dir_LR,
                           bbox_annos,
                           hparams)
    img_loader = DataLoader(dataset['img_valid'],
                            batch_size=64,
                            shuffle=False,
                            num_workers=16)
    print('num of test images =', len(dataset['img_valid']))

    # load trained model
    input_size = 134  # number of belief maps
    task_eye = torch.eye(len(dataset['catIds'])).to(device)
    generator = LHF_Policy_Cond_Small(hparams.Data.patch_count,
                                      len(dataset['catIds']), task_eye,
                                      input_size).to(device)
    state = torch.load(join(checkpoint, 'orig_trained_generator.pkg'), map_location=device)
    generator.load_state_dict(state['model'])

    # generator.load_state_dict(
    #     torch.load(join(checkpoint, 'trained_generator.pkg'),
    #                map_location=device))

    generator.eval()

    # build environment
    env_test = IRL_Env4LHF(hparams.Data,
                           max_step=hparams.Data.max_traj_length,
                           mask_size=hparams.Data.IOR_size,
                           status_update_mtd=hparams.Train.stop_criteria,
                           device=device,
                           inhibit_return=True)

    # generate scanpaths
    print('sample scanpaths (10 for each testing image)...')
    predictions = gen_scanpaths(generator,
                                env_test,
                                img_loader,
                                hparams,
                                num_sample=10)

    # compute multimatch
    res = metrics.compute_mm(human_scanpaths_test, predictions, hparams.Data.im_w, hparams.Data.im_h)
    print('Multimatch done: ', res)
ouyangzhibo commented 2 years ago

[0.00793765 0.42771553 0.74378315 0.85568454 0.88808041 0.89897851 0.90245387]

I suppose these numbers are the target fixation probabilities of the ground-truth scanpaths, right? It is normal, because there are three cases: scanpaths that actually find the target but are longer than 6 fixations; "incorrect" scanpaths; and correct scanpaths without any fixation located within the target bounding box. Note that we exclude the latter two cases from training and testing.
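In code, excluding the latter two cases from the ground truth could look roughly like this (a sketch; the bounding-box lookup key and the [x, y, w, h] box format are assumptions, not taken from the repo):

def fixates_target(scanpath, bbox):
    # True if any fixation lands inside the target bounding box
    # (bbox assumed to be [x, y, w, h] in the same 512x320 coordinates).
    x, y, w, h = bbox
    return any(x <= fx <= x + w and y <= fy <= y + h
               for fx, fy in zip(scanpath['X'], scanpath['Y']))

human_scanpaths_test = [
    sp for sp in human_scanpaths_test
    if sp['correct'] == 1  # drop "incorrect" trials
    and fixates_target(sp, bbox_annos[sp['task'] + '_' + sp['name']])  # hypothetical key format
]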

StoyanVenDimitrov commented 2 years ago

OK, I didn't exclude correct scanpaths without a fixation located within the target bounding box. Nevertheless, I got MM scores of [0.89088992 0.69892884 0.87192403 0.89618582] with my trained model, which is on par with your reported results.

quangdaist01 commented 2 years ago

I have excluded the latter two cases when loading the ground-truth and predicted scanpaths, and I use the checkpoint in the repo, but the result is not the same as reported. Please have a look. Thank you!

The result in the paper: (table screenshot attached)
My result:
Probability Mismatch: 0.8589589478932521
Scanpath Ratio: 0.857290438249376
MultiMatch: [0.88629528 0.71560808 0.86273588 0.91471171]

The code:


"""Test script.
Usage:
  test.py <hparams> <dataset_root> <checkpoint_dir> [--cuda=<id>]
  test.py -h | --help
Options:
  -h --help     Show this screen.
  --cuda=<id>   id of the cuda device [default: 0].
"""
import os, json
import torch
import numpy as np
from tqdm import tqdm
from os.path import join
from dataset import process_data
from irl_dcb.config import JsonConfig
from torch.utils.data import DataLoader
from irl_dcb.models import LHF_Policy_Cond_Small
from irl_dcb.environment import IRL_Env4LHF
from irl_dcb import metrics
from irl_dcb import utils
from irl_dcb.utils import compute_search_cdf

torch.manual_seed(42619)
np.random.seed(42619)

def gen_scanpaths(generator,
                  env_test,
                  test_img_loader,
                  hparams,
                  num_sample=10):
    patch_num = hparams.Data.patch_num
    max_traj_len = hparams.Data.max_traj_length
    all_actions = []
    for i_sample in range(num_sample):
        progress = tqdm(test_img_loader,
                        desc='trial ({}/{})'.format(i_sample + 1, num_sample))
        for i_batch, batch in enumerate(progress):
            env_test.set_data(batch)
            img_names_batch = batch['img_name']
            cat_names_batch = batch['cat_name']
            with torch.no_grad():
                env_test.reset()
                trajs = utils.collect_trajs(env_test,
                                            generator,
                                            patch_num,
                                            max_traj_len,
                                            is_eval=True,
                                            sample_action=True)
                all_actions.extend([(cat_names_batch[i], img_names_batch[i],
                                     'present', trajs['actions'][:, i])
                                    for i in range(env_test.batch_size)])

    scanpaths = utils.actions2scanpaths(all_actions, patch_num, hparams.Data.im_w, hparams.Data.im_h)
    utils.cutFixOnTarget(scanpaths, bbox_annos)

    return scanpaths

device = torch.device('cpu')
hparams = r'C:\Users\MSI I5\PycharmProjects\Scanpath_Prediction\hparams\coco_search18.json'
dataset_root = r'C:\Users\MSI I5\PycharmProjects\Scanpath_Dataset\processed-20220328T090355Z-001\processed'
# note: this second assignment overrides the dataset_root set on the previous line
dataset_root = r'D:\OneDrive - Trường ĐH CNTT - University of Information Technology\Máy tính\DCBs'
checkpoint = r'C:\Users\MSI I5\PycharmProjects\Scanpath_Prediction\trained_models'
hparams = JsonConfig(hparams)
bbox_annos = np.load(join(dataset_root, 'bbox_annos.npy'),
                     allow_pickle=True).item()
with open(join(dataset_root,
               'human_scanpaths_TP_trainval_train.json')) as json_file:
    human_scanpaths_train = json.load(json_file)

# ! coco test data instead of validation set
with open(r'C:\Users\MSI I5\PycharmProjects\Scanpath_Prediction\coco_search18_fixations_TP_test.json') as json_file:
    human_scanpaths_test = json.load(json_file)

human_scanpaths_test = list(filter(lambda x: x['correct'] == 1, human_scanpaths_test))
human_scanpaths_test = list(filter(lambda item: len(item['X']) <= 6, human_scanpaths_test))

for scanpath in human_scanpaths_test:
    scanpath['X'] = [x * 512 / 1680 for x in scanpath['X']]
    scanpath['Y'] = [x * 320 / 1050 for x in scanpath['Y']]

# dir of pre-computed beliefs
DCB_dir_HR = join(dataset_root, 'DCBs/HR/')
DCB_dir_LR = join(dataset_root, 'DCBs/LR/')
data_name = '{}x{}'.format(hparams.Data.im_w, hparams.Data.im_h)

# process fixation data
dataset = process_data(human_scanpaths_train, human_scanpaths_test,
                       DCB_dir_HR,
                       DCB_dir_LR,
                       bbox_annos,
                       hparams)
img_loader = DataLoader(dataset['img_valid'],
                        batch_size=64,
                        shuffle=False,
                        num_workers=0)
print('num of test images =', len(dataset['img_valid']))

# load trained model
input_size = 134  # number of belief maps
task_eye = torch.eye(len(dataset['catIds'])).to(device)
generator = LHF_Policy_Cond_Small(hparams.Data.patch_count,
                                  len(dataset['catIds']), task_eye,
                                  input_size).to(device)
state = torch.load(join(checkpoint, 'trained_generator.pkg'), map_location=device)
generator.load_state_dict(state['model'])

generator.eval()

# build environment
env_test = IRL_Env4LHF(hparams.Data,
                       max_step=hparams.Data.max_traj_length,
                       mask_size=hparams.Data.IOR_size,
                       status_update_mtd=hparams.Train.stop_criteria,
                       device=device,
                       inhibit_return=True)

# generate scanpaths
print('sample scanpaths (10 for each testing image)...')
predictions = gen_scanpaths(generator,
                            env_test,
                            img_loader,
                            hparams,
                            num_sample=10)
predictions = list(filter(lambda item: len(item['X']) <= 6, predictions))

# write predictions.txt
with open('predictions.txt', 'w') as f:
    f.write(str(predictions))

# compute multimatch
res = metrics.compute_mm(human_scanpaths_test, predictions, hparams.Data.im_w, hparams.Data.im_h)
SPratio = metrics.compute_avgSPRatio(predictions, bbox_annos, 10)
# SPratio = metrics.compute_avgSPRatio(predictions, human_scanpaths_test, 10)
print('Scanpath ratio : ', SPratio)
print('Multimatch done: ', res)

# probability mismatch
human_mean_cdf, _ = compute_search_cdf(human_scanpaths_test, bbox_annos, hparams.Data.max_traj_length)
mean_cdf, _ = compute_search_cdf(predictions, bbox_annos, hparams.Data.max_traj_length)
sad = np.sum(np.abs(human_mean_cdf - mean_cdf))
print('Probability Mismatch: ', sad)
StoyanVenDimitrov commented 2 years ago

Hi, I also have one question about the "Human" results in Table 2. You say ""Human" refers to an oracle method where one searcher's scanpath is used to predict another searcher's scanpath". What do you mean exactly? Isn't it the same as the line res = metrics.compute_mm(human_scanpaths_test, human_scanpaths_test, hparams.Data.im_w, hparams.Data.im_h), similar to the script above, i.e., just the MM score between human scanpaths? Thank you.

ouyangzhibo commented 2 years ago

@quangdaist123 There are two reasons: 1. During testing you sample scanpaths according to the model predictions, and this step is subject to randomness during sampling. 2. We later retrained the model with some hyper-parameter tuning and published the newest pre-trained models. Therefore, it is likely that you would observe slightly better results.
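Regarding point 1, one way to gauge how much the sampling randomness matters is to repeat the evaluation with a few different seeds and average the scores (a sketch reusing gen_scanpaths and metrics.compute_mm from the scripts above):

import numpy as np
import torch

# Repeat sampling with different seeds to estimate the variance due to point 1.
seed_scores = []
for seed in (0, 1, 2):
    torch.manual_seed(seed)
    np.random.seed(seed)
    preds = gen_scanpaths(generator, env_test, img_loader, hparams, num_sample=10)
    seed_scores.append(metrics.compute_mm(human_scanpaths_test, preds,
                                          hparams.Data.im_w, hparams.Data.im_h))
print('MultiMatch mean over seeds:', np.mean(seed_scores, axis=0))
print('MultiMatch std over seeds:', np.std(seed_scores, axis=0))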

ouyangzhibo commented 2 years ago

Hi, I also have one question about the "Human" results in Table 2. You say ""Human" refers to an oracle method where one searcher's scanpath is used to predict another searcher's scanpath". What do you mean exactly? Isn't it the same as the line res = metrics.compute_mm(human_scanpaths_test, human_scanpaths_test, hparams.Data.im_w, hparams.Data.im_h), similar to the script above, i.e., just the MM score between human scanpaths? Thank you.

Specifically, suppose a scanpath from subject 1 is our target; the scanpaths of the remaining 9 subjects are our predictions (similar to our model sampling 10 scanpaths as predictions).
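In Python-like pseudocode, that oracle could be computed roughly as follows (a sketch; scanpaths_by_image and multimatch_fn are hypothetical placeholders, not functions from this repo):

import numpy as np

def human_oracle_mm(scanpaths_by_image, multimatch_fn):
    # scanpaths_by_image: dict mapping (image, task) to the list of subject scanpaths.
    # multimatch_fn: returns the MultiMatch components for a pair of scanpaths.
    scores = []
    for scanpaths in scanpaths_by_image.values():
        for i, target in enumerate(scanpaths):          # each subject's scanpath is the target
            others = scanpaths[:i] + scanpaths[i + 1:]  # the other subjects are the "predictions"
            scores.extend(multimatch_fn(target, other) for other in others)
    return np.mean(scores, axis=0)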