xiangjieSui / ScanDMM

[2023-CVPR] ScanDMM: A Deep Markov Model of Scanpath Prediction for 360-degree Images

want evaluation metrics code #1

Open lsztzp opened 10 months ago

lsztzp commented 10 months ago

Can you provide the code for the evaluation metrics?

xiangjieSui commented 10 months ago

Please refer to https://github.com/rAm1n/saliency/blob/master/metrics/metrics.py
Thanks.

MarcWong commented 10 months ago

Dear authors, thanks for your response, but could you provide more details about how the evaluation was executed? Did you evaluate scanpaths in image space, i.e., (x, y) in pixels? What thresholds did you use for the REC and DET metrics?

xiangjieSui commented 10 months ago

Sure, we evaluate scanpaths in image space.
We set the threshold to 2 * 6 for an image of size (128, 256).

You may need to change the threshold for other image sizes; please see the following paper for details:

Ramin Fahimi and Neil D. B. Bruce, "On metrics for measuring scanpath similarity," Behavior Research Methods, 2021. https://link.springer.com/article/10.3758/s13428-020-01441-0
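
For other resolutions, one simple reading of the advice above is to scale the pixel threshold linearly with image height. A hypothetical helper (rec_threshold is not from the repo):

def rec_threshold(height, base_threshold=2 * 6, base_height=128):
    # Scale the (128, 256) threshold of 2 * 6 pixels linearly with image height.
    return base_threshold * height / base_height

# e.g., rec_threshold(720) == 67.5 for a (720, 1440) image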

lsztzp commented 9 months ago

Hello, thank you very much for your earlier help. I found that when processing the three datasets other than Sitzmann, the lengths of their scanpaths are very inconsistent and the timestamps are not continuous, so they cannot simply be resampled at 1 Hz. I would like more details on how you processed the other three datasets, and the lengths of the generated scanpaths used when evaluating on them. Ideally, could you provide the corresponding data-processing code? Thank you!

xiangjieSui commented 9 months ago

Yes, some of them are hard to resample at exactly 1 Hz, so what we actually did was approximate 1 Hz. Sure, I can provide the code, but I have not cleaned it, so it may be a bit difficult to read. I will paste it here in the next few days; I have been busy recently, so please remind me if I forget.
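
As a minimal sketch of what approximately-1 Hz resampling can look like (a hypothetical helper, not the authors' code, which is pasted below):

import numpy as np

def resample_approx_1hz(timestamps, gaze):
    # timestamps: (N,) seconds, possibly irregular or gappy; gaze: (N, 2) [lat, lon].
    # For each whole second in the recording, keep the sample whose timestamp
    # is nearest, yielding an approximately 1 Hz scanpath.
    timestamps = np.asarray(timestamps, dtype=float)
    gaze = np.asarray(gaze, dtype=float)
    targets = np.arange(np.ceil(timestamps[0]), np.floor(timestamps[-1]) + 1.0)
    nearest = np.abs(timestamps[None, :] - targets[:, None]).argmin(axis=1)
    return gaze[nearest]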

xiangjieSui commented 9 months ago

# Imports assumed for this paste; config, suppor_lib, create_info,
# image_process, search_dir, and seed are repo-local names defined elsewhere.
import copy
import os

import numpy as np
import torch
import xlrd
from tqdm import tqdm

import config
import suppor_lib

class Salient360_Dataset():
    def __init__(self):
        super().__init__()
        self.images_path = config.dic_Salient360['IMG_PATH']
        self.gaze_path = config.dic_Salient360['GAZE_PATH']
        self.test_set = config.dic_Salient360['TEST_SET']

    def create_tree(self, img_list):
        mode = {
            'image': [],
            'scanpaths': []
        }
        dic = {'train': {}, 'test': {}, 'info': {}}
        for img in img_list:
            if img in self.test_set:
                dic['test'][img] = copy.deepcopy(mode)
            else:
                dic['train'][img] = copy.deepcopy(mode)
        return dic

    def data_processing(self):
        info = create_info()
        image_name_list = []
        for file_name in os.listdir(self.images_path):
            image_name_list.append(file_name.split('.')[0])
        dic = self.create_tree(image_name_list)
        for file_name in tqdm(os.listdir(self.images_path)):
            image_name = file_name.split('.')[0]
            its_index = image_name.split('_')[0].split('P')[1]
            type = ('train', 'test')[image_name in self.test_set]

            '''
            Notice: Salient360! recorded left (L) and right (R) eye data.
                    Intuitively, the L and R data should be nearly identical,
                    since both eyes focus on the same region at the same time,
                    but we found that they are not. We observed that the
                    head-movement data approximate the average of L and R,
                    so we treated the head-movement data as the gaze data.
            '''

            gaze_file = self.gaze_path + 'Hscanpath_' + its_index + '.txt'
            index = 0
            scanpaths, x, y = [], [], []
            with open(gaze_file, "r") as f:
                for row in f:
                    if index == 0:  # skip the header line
                        index += 1
                        continue
                    row_id = int(row.split(',')[0])
                    lon = float(row.split(',')[1])
                    lat = float(row.split(',')[2])

                    if (row_id + 1) // 100 == 1:  # next user: flush the collected points
                        _gaze = np.concatenate(
                            (np.array(y).reshape(-1, 1), np.array(x).reshape(-1, 1)), axis=1)
                        scanpaths.append(suppor_lib.latlontoxyz(
                            suppor_lib.plane2sphere(torch.from_numpy(_gaze))).numpy())
                        x, y = [], []
                    elif row_id % 4 == 0:  # sample 25 of the 100 points per user
                        x.append(lon)
                        y.append(lat)
            scanpaths = np.array(scanpaths)
            image = image_process(
                self.images_path + file_name, need_ratio=False)

            info[type]['num_image'] += 1
            info[type]['num_scanpath'] += scanpaths.shape[0]
            dic[type][image_name]['scanpaths'] = scanpaths
            dic[type][image_name]['image'] = image

        info['train']['scanpath_length'], info['test']['scanpath_length'] = 25, 25
        dic['info'] = info

        return dic

class AOI_Dataset(): 
    def __init__(self):
        super().__init__()
        self.images_path = config.dic_AOI['IMG_PATH']
        self.gaze_path = config.dic_AOI['GAZE_PATH']
        self.test_set = config.dic_AOI['TEST_SET']

    def create_tree(self, img_list):
        mode = {
            'image': [],
            'scanpaths': []
        }
        dic = {'train': {}, 'test': {}, 'info': {}}
        for img in img_list:
            full_name = img + '.jpg'
            if full_name in self.test_set:
                dic['test'][img] = copy.deepcopy(mode)
            else:
                dic['train'][img] = copy.deepcopy(mode)
        return dic

    def data_processing(self):
        info = create_info()
        image_name_list = []
        for file_name in os.listdir(self.images_path):
            image_name_list.append(file_name.split('.')[0])
        dic = self.create_tree(image_name_list)
        for file_name in tqdm(os.listdir(self.images_path)):
            image_name = file_name.split('.')[0]
            type = ('train', 'test')[file_name in self.test_set]  # select the split; used below
            gaze_file = self.gaze_path + image_name + '.txt'
            index = 0
            _scanpaths = {}
            with open(gaze_file, "r") as f:
                for row in f:
                    if index == 0:  # skip the header line
                        index += 1
                        continue
                    subject_id = row.split(' ')[0]
                    lon = float(row.split(' ')[-2])
                    # negate latitude: [bottom -> top] becomes [90 -> -90]
                    lat = -float(row.split(' ')[-1])
                    _scanpaths.setdefault(subject_id, []).append([lat, lon])
            temp = []
            for subject_id in _scanpaths:
                _scanpath = np.array(_scanpaths[subject_id])
                max_scan_length = info[type]['max_scan_length']
                info[type]['num_scanpath'] += 1
                info[type]['max_scan_length'] = max(
                    max_scan_length, _scanpath.shape[0])
                temp.append(suppor_lib.latlontoxyz(
                    torch.from_numpy(_scanpath)).numpy())

            image = image_process(
                self.images_path + file_name, need_ratio=False)

            dic[type][image_name]['scanpaths'] = temp
            dic[type][image_name]['image'] = image
            info[type]['num_image'] += 1
        info['train']['scanpath_length'] = -1  # variable-length scanpaths
        info['test']['scanpath_length'] = -1
        dic['info'] = info

        return dic

class JUFE_Dataset():
    def __init__(self):
        super().__init__()
        self.images_path = config.dic_JUFE['IMG_PATH']
        self.gaze_path = config.dic_JUFE['GAZE_PATH']
        self.mos_path = config.dic_JUFE['MOS_PATH']
        mysplit = torch.load(open('./dataset_split_seed-' + str(seed), 'rb'))
        train_set = mysplit['train']
        # caution: test_set is read from the 'val' key and val_set from the 'test' key
        test_set = mysplit['val']
        val_set = mysplit['test']
        self.test_set = list(set([test_set[i].split('_')[0]
                             for i in range(len(test_set))]))
        self.train_set = list(
            set([train_set[i].split('_')[0] for i in range(len(train_set))]))
        self.val_set = list(set([val_set[i].split('_')[0]
                            for i in range(len(val_set))]))
        # self.test_set = config.dic_JUFE['TEST_SET']

    def getFileName(self, root, target_file):
        re = []
        res_list = search_dir(root, target_file, re)
        return res_list

    def mod(self, a, b):
        c = a // b
        r = a - c * b
        return r

    def create_tree(self, img_list):
        mode = {
            'image': [],
            'good': {
                '5s': {
                    'mos': [],
                    'scanpaths': []
                },
                '15s': {
                    'mos': [],
                    'scanpaths': []
                }
            },
            'bad': {
                '5s': {
                    'mos': [],
                    'scanpaths': []
                },
                '15s': {
                    'mos': [],
                    'scanpaths': []
                }
            }
        }
        dic = {'train': {}, 'val': {}, 'test': {}, 'info': {}}
        for img in img_list:
            its_index = img.split('_')[0]
            if its_index in self.test_set:
                dic['test'][img] = copy.deepcopy(mode)
            # else:
            elif its_index in self.train_set:
                dic['train'][img] = copy.deepcopy(mode)
            else:
                dic['val'][img] = copy.deepcopy(mode)
        return dic

    def data_processing(self):

        MOS_file = xlrd.open_workbook(self.mos_path).sheet_by_name('Sheet1')
        image_list = MOS_file.col_values(1)[1:]
        start_point = MOS_file.col_values(7)[1:]
        exp_time = MOS_file.col_values(6)[1:]
        MOS = MOS_file.col_values(2)[1:]

        # strip file extensions
        image_list = [img.split('.')[0] for img in image_list]
        index = 0

        dic = self.create_tree(image_list)
        info = create_info()

        for image in image_list:
            image_name = image.split('.')[0]
            its_index = image_name.split('_')[0]
            # type = ('train', 'test')[its_index in self.test_set]
            if its_index in self.test_set:
                type = 'test'
            elif its_index in self.train_set:
                type = 'train'
            else:
                type = 'val'
            st = ('good', 'bad')[int(start_point[index]) != 1]
            dic[type][image_name][st][exp_time[index]]['mos'] = round(
                float(MOS[index]), 4)  # np.float was removed in NumPy 1.24; use float
            info[type]['num_mos'] += 1
            index += 1

        for file_name in tqdm(os.listdir(self.images_path)):
            image_name = file_name.split(".")[0]
            its_index = image_name.split('_')[0]
            # type = ('train', 'test')[its_index in self.test_set]
            if its_index in self.test_set:
                type = 'test'
            elif its_index in self.train_set:
                type = 'train'
            else:
                type = 'val'
            info[type]['num_image'] += 1

            csv_path = self.getFileName(self.gaze_path, image_name + '.csv')
            scanpaths_good, scanpaths_bad = [], []
            for csv in csv_path:
                data = xlrd.open_workbook(csv).sheet_by_name('Sheet1')
                st = ('good', 'bad')['bad' in csv]
                lat, lon = np.array(data.col_values(
                    3)).reshape(-1, 1), np.array(data.col_values(4)).reshape(-1, 1)
                _gaze = np.concatenate((lat, lon), axis=1)
                gaze = []
                # fps = 1: sample one point per second
                fps = 1
                num_sample = fps * 15  # 15-second viewing duration
                step = int(_gaze.shape[0] / num_sample)
                for j in range(num_sample):
                    gaze.append(_gaze[j * step])
                gaze = suppor_lib.latlontoxyz(
                    torch.from_numpy(np.array(gaze))).numpy()
                if st == 'good':
                    scanpaths_good.append(gaze)
                else:
                    scanpaths_bad.append(gaze)
                info[type]['num_scanpath'] += 1

            scanpaths_good, scanpaths_bad = np.array(
                scanpaths_good), np.array(scanpaths_bad)
            exp_time_5s = int(scanpaths_good.shape[1] / 3)  # first third of 15 s = 5 s
            image = image_process(
                self.images_path + file_name, need_ratio=False)

            dic[type][image_name]['image'] = image
            dic[type][image_name]['good']['5s']['scanpaths'] = scanpaths_good[:, :exp_time_5s, :]
            dic[type][image_name]['good']['15s']['scanpaths'] = scanpaths_good
            dic[type][image_name]['bad']['5s']['scanpaths'] = scanpaths_bad[:, :exp_time_5s, :]
            dic[type][image_name]['bad']['15s']['scanpaths'] = scanpaths_bad
        info['train']['scanpath_length'], info['test']['scanpath_length'], info['val']['scanpath_length'] = \
            [5, 15], [5, 15], [5, 15]
        dic['info'] = info

        return dic
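
The helpers latlontoxyz, xyztolatlon, plane2sphere, and sphere2plane above come from the repo's suppor_lib and are not shown in the paste. A rough sketch of plausible implementations, where the degree ranges and (y, x) axis order are assumptions rather than the repo's exact conventions:

import torch

def latlontoxyz(latlon):
    # (N, 2) [lat, lon] in degrees -> (N, 3) unit vectors on the sphere.
    lat = torch.deg2rad(latlon[:, 0])
    lon = torch.deg2rad(latlon[:, 1])
    x = torch.cos(lat) * torch.cos(lon)
    y = torch.cos(lat) * torch.sin(lon)
    z = torch.sin(lat)
    return torch.stack((x, y, z), dim=1)

def xyztolatlon(xyz):
    # (N, 3) unit vectors -> (N, 2) [lat, lon] in degrees.
    lat = torch.rad2deg(torch.asin(xyz[:, 2].clamp(-1.0, 1.0)))
    lon = torch.rad2deg(torch.atan2(xyz[:, 1], xyz[:, 0]))
    return torch.stack((lat, lon), dim=1)

def plane2sphere(yx):
    # (N, 2) normalized (y, x) in [0, 1] -> (N, 2) [lat, lon] in degrees.
    lat = 90.0 - yx[:, 0] * 180.0
    lon = yx[:, 1] * 360.0 - 180.0
    return torch.stack((lat, lon), dim=1)

def sphere2plane(latlon, size):
    # (N, 2) [lat, lon] in degrees -> (N, 2) (row, col) pixels for an
    # equirectangular image of shape size = (H, W).
    h, w = size
    row = (90.0 - latlon[:, 0]) / 180.0 * h
    col = (latlon[:, 1] + 180.0) / 360.0 * w
    return torch.stack((row, col), dim=1)
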
xiangjieSui commented 9 months ago

Here is evaluation.py:

from pyro.infer import Predictive
from config import image_size
import pyro.contrib.examples.polyphonic_data_loader as poly
from suppor_lib import *
from tqdm import tqdm
from scanpath_metrics import *

def eval(Model=None,
         aoi=None, jxufe=None, salient360=None, sitzmann=None,
         sample_rate=1, with_cnn=False, train_set=None,
         nums_predition_loop=1, num_samples=10, is_train=False, random_st=False, epoch=None, results_log=None):

    def create_random_start_points(num_points):
        y, x = [], []
        for i in range(num_points):
            while True:
                temp = np.random.normal(loc=0, scale=0.2)
                if (temp <= 1) and (temp >= -1):
                    y.append(temp)
                    break
            x.append(np.random.uniform(-1, 1))
        cords = np.vstack((np.array(y) * 90, np.array(x) * 180)).swapaxes(0, 1)
        cords = latlontoxyz(torch.from_numpy(cords))
        return cords

    def summary(samples, gt):
        """
        Transform 3D (x,y,z) to 2D (x, y) ranging in [128, 256]
        """
        obs = None
        obs_gt = None
        for index in range(int(len(samples) / 2)):
            name = 'obs_x_' + str(index + 1)
            temp = samples[name].reshape([-1, 3])
            its_sum = torch.sqrt(temp[:, 0] ** 2 + temp[:, 1] ** 2 + temp[:, 2] ** 2)
            temp = temp / torch.unsqueeze(its_sum, 1)
            if obs is not None:
                obs = torch.cat((obs, torch.unsqueeze(sphere2plane(xyztolatlon(temp), [128, 256]), dim=0)), dim=0)
            else:
                obs = torch.unsqueeze(sphere2plane(xyztolatlon(temp), [128, 256]), dim=0)

        for index in range(gt.shape[1]):
            if obs_gt is not None:
                obs_gt = torch.cat(
                    (obs_gt, torch.unsqueeze(sphere2plane(xyztolatlon(gt[:, index, :]), [128, 256]), dim=0)),
                    dim=0)
            else:
                obs_gt = torch.unsqueeze(sphere2plane(xyztolatlon(gt[:, index, :]), [128, 256]), dim=0)

        return obs, obs_gt

    def pred_loop(data, dataset_name: str):
        """"""
        " Get required information "
        image_index = data["image_index"]
        image_paths = data['img_paths']
        is_aoi = (dataset_name == 'AOI')
        num_loop = image_index.max() + 1

        " Initialize results "
        re1, re2, re3 = [], [], []
        auc, auc_s, nss, cc, sim, kld = [], [], [], [], [], []

        for image_i in tqdm(range(num_loop)):
            " Find related index  "
            test_seq_index = torch.where(image_index == image_i)[0]

            " We use sample_rate to speed up when training "
            sample_index = test_seq_index[0:-1:sample_rate]

            " Get the lengths we need to predict "
            test_seq_lengths = data["sequence_lengths"][sample_index]
            if is_aoi:
                lengths_temp = test_seq_lengths
                test_seq_lengths = (torch.ones_like(lengths_temp) * 22).int()

            " Get the GT scanpaths (we only use the x_1 for initializing state) "
            test_data_sequences = data["sequences"][sample_index]

            " Get the transformed images and original images (the latter are used for drawing saliency map) "
            test_image = data["images"][sample_index]
            img_path = image_paths[image_i]

            " Setting random starting point when test model "
            if random_st:
                random_start_points = create_random_start_points(test_data_sequences.shape[0])
                test_data_sequences[:, 0] = random_start_points

            " Compute the mask "
            test_mask = poly.get_mini_batch_mask(test_data_sequences, test_seq_lengths)

            " Feeding to GPU "
            test_batch = test_data_sequences.cuda()
            test_batch_mask = test_mask.cuda()
            mini_batch_images = test_image.cuda()

            " Prediction "
            with torch.no_grad():
                if with_cnn:
                    image_data = mini_batch_images
                else:
                    image_data = None
                samples = predictive(scanpaths=test_batch,
                                     scanpaths_reversed=None,
                                     mask=test_batch_mask,
                                     scanpath_lengths=None,
                                     images=image_data,
                                     predict=True)

                " Process the scanpaths: 3D coords (x, y, z) -> 2D plane (x, y) "
                pred_summary, gt_summary = summary(samples, test_batch)
                pred_scanpath = pred_summary.cpu().numpy()
                gt_scanpath = gt_summary.cpu().numpy()

                """  For AOI database, we uniformly sampling gaze points from 22 seconds scanpaths,
                 ensuring lengths of produced scanpaths = lengths of GT scanpaths """
                if is_aoi:
                    temp = np.zeros_like(gt_scanpath)
                    for user_i in range(gt_scanpath.shape[1]):
                        current_length = lengths_temp[user_i]
                        step = int(22 / current_length)
                        for sample_i in range(current_length):
                            temp[sample_i, user_i] = pred_scanpath[sample_i * step, user_i]
                    pred_scanpath = temp
                    test_seq_lengths = lengths_temp
                test_seq_lengths = test_seq_lengths.numpy()

                " Evaluate the produced scanpaths "
                lev, dtw, rec = compute_scanpath_mertics(pred_scanpath, gt_scanpath, image_size, test_seq_lengths)
                re1.append(lev), re2.append(dtw), re3.append(rec)

                if not is_train:
                    " Plot scanpaths "
                    plot_root = './results/scanpaths/'
                    if not os.path.exists(plot_root): os.makedirs(plot_root)
                    plot_scanpaths(pred_scanpath, img_path, test_seq_lengths,
                                   save_path=os.path.join(plot_root + dataset_name))

                    " Normalize (x, y) to range [0, 1] "
                    for user_i in range(test_seq_lengths.shape[0]):
                        length = test_seq_lengths[user_i]
                        if user_i == 0:
                            gt_norm = gt_scanpath[:length, user_i, :] / image_size
                            pred_norm = pred_scanpath[:length, user_i, :] / image_size
                        else:
                            gt_norm = np.vstack((gt_norm, gt_scanpath[:length, user_i, :] / image_size))
                            pred_norm = np.vstack((pred_norm, pred_scanpath[:length, user_i, :] / image_size))

                    " Get the salmap and compute saliency metrics "
                    sal_metrics = get_salmaps(pred_norm, gt_norm, [720, 1440], dataset_name, img_path)
                    sal_metrics_res = np.array(sal_metrics)
                    auc.append(sal_metrics_res[:, 0])
                    nss.append(sal_metrics_res[:, 1])
                    cc.append(sal_metrics_res[:, 2])
                    kld.append(sal_metrics_res[:, 3])

        lev, dtw, rec = np.nanmean(np.array(re1)), np.nanmean(np.array(re2)), np.nanmean(np.array(re3))

        if not is_train:
            auc1, nss1 = np.mean(np.array(auc), axis=0), np.mean(np.array(nss), axis=0)
            cc1, kld1 = np.mean(np.array(cc), axis=0), np.mean(np.array(kld), axis=0)
            return [lev, dtw, rec, auc1, nss1, cc1, kld1]

        else:
            return [lev, dtw, rec]

    def log_performances(dataset_name: str, res):
        log_root = results_log
        if res is None:
            res = '\n'
        elif isinstance(res, int):
            res = dataset_name + str(res) + '\n'
        else:
            _res = dataset_name + '\t' + \
                   'LEV=' + str(round(res[0], 3)) + '\t' + \
                   'DTW=' + str(round(res[1], 3)) + '\t' + \
                   'REC=' + str(round(res[2], 3)) + '\n'
            res = _res
        with open(log_root, 'a', encoding='utf8') as f:
            f.write(res)

    predictive = Predictive(Model.model, num_samples=num_samples)
    lev, dtw, rec = np.zeros(4), np.zeros(4), np.zeros(4)
    dic = {'AOI': 0, 'JUFE': 1, 'Salient360': 2, 'Sitzmann': 3}

    log_performances('------ Epoch:', epoch)

    if sitzmann is not None:
        for i in range(nums_predition_loop):
            _res = pred_loop(sitzmann, 'Sitzmann')
            if i == 0:
                res = np.array(_res)
            else:
                res += np.array(_res)
        res = res / nums_predition_loop
        lev[3], dtw[3], rec[3] = res[0], res[1], res[2]
        log_performances('Sitzmann', res)

    if jxufe is not None:
        for i in range(nums_predition_loop):
            _res = pred_loop(jxufe, 'Jxufe')
            if i == 0:
                res = np.array(_res)
            else:
                res += np.array(_res)
        res = res / nums_predition_loop
        lev[1], dtw[1], rec[1] = res[0], res[1], res[2]
        log_performances('Jxufe', res)

    if not is_train:
        if salient360 is not None:
            for i in range(nums_predition_loop):
                _res = pred_loop(salient360, 'Salient360')
                if i == 0:
                    res = np.array(_res)
                else:
                    res += np.array(_res)
            res = res / nums_predition_loop
            lev[2], dtw[2], rec[2] = res[0], res[1], res[2]
            log_performances('Salient360', res)

        if aoi is not None:
            for i in range(nums_predition_loop):
                _res = pred_loop(aoi, 'AOI')
                if i == 0:
                    res = np.array(_res)
                else:
                    res += np.array(_res)
            res = res / nums_predition_loop
            lev[0], dtw[0], rec[0] = res[0], res[1], res[2]
            log_performances('AOI', res)

        if jxufe is not None:
            for i in range(nums_predition_loop):
                _res = pred_loop(jxufe, 'Jxufe')
                if i == 0:
                    res = np.array(_res)
                else:
                    res += np.array(_res)
            res = res / nums_predition_loop
            lev[1], dtw[1], rec[1] = res[0], res[1], res[2]
            log_performances('Jxufe', res)

    log_performances('\n', None)

    i = dic[train_set]
    return lev[i], dtw[i], rec[i]
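
compute_scanpath_mertics from scanpath_metrics.py is not included in the thread (it is requested again below). A minimal self-contained sketch of what it plausibly computes, assuming Levenshtein distance on grid-quantized fixations, DTW with Euclidean cost, and cross-recurrence, following Fahimi and Bruce rather than calling the rAm1n/saliency implementations directly; the cell parameter is an assumption:

import numpy as np

def _lev(a, b):
    # Levenshtein edit distance between two token sequences.
    m, n = len(a), len(b)
    d = list(range(n + 1))
    for i in range(1, m + 1):
        prev, d[0] = d[0], i
        for j in range(1, n + 1):
            cur = d[j]
            d[j] = min(d[j] + 1, d[j - 1] + 1, prev + (a[i - 1] != b[j - 1]))
            prev = cur
    return d[n]

def _dtw(p, q):
    # Dynamic time warping with Euclidean ground cost.
    D = np.full((len(p) + 1, len(q) + 1), np.inf)
    D[0, 0] = 0.0
    for i in range(1, len(p) + 1):
        for j in range(1, len(q) + 1):
            cost = np.linalg.norm(p[i - 1] - q[j - 1])
            D[i, j] = cost + min(D[i - 1, j], D[i, j - 1], D[i - 1, j - 1])
    return D[-1, -1]

def _rec(p, q, threshold):
    # Cross-recurrence: percentage of fixation pairs within `threshold` pixels.
    n = min(len(p), len(q))
    dist = np.linalg.norm(p[:n, None, :] - q[None, :n, :], axis=-1)
    return 100.0 * np.sum(dist <= threshold) / (n * n)

def compute_scanpath_mertics(pred, gt, image_size, lengths, cell=16):
    # pred, gt: (T, num_users, 2) scanpaths in pixels; lengths[i] is the valid
    # length for user i. The 2 * 6 threshold quoted above for (128, 256)
    # images is scaled linearly with image height.
    threshold = 2 * 6 * image_size[0] / 128
    levs, dtws, recs = [], [], []
    for i in range(gt.shape[1]):
        p, g = pred[:lengths[i], i], gt[:lengths[i], i]
        tokens_p = [tuple((f // cell).astype(int)) for f in p]
        tokens_g = [tuple((f // cell).astype(int)) for f in g]
        levs.append(_lev(tokens_p, tokens_g))
        dtws.append(_dtw(p, g))
        recs.append(_rec(p, g, threshold))
    return np.nanmean(levs), np.nanmean(dtws), np.nanmean(recs)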
lsztzp commented 9 months ago

OK, thanks again for your help!

lsztzp commented 9 months ago

Hello, when you have time, could you provide the function compute_scanpath_mertics(pred_scanpath, gt_scanpath, image_size, test_seq_lengths) from scanpath_metrics.py? It would help me a lot. Thank you again.

lsztzp commented 8 months ago

Hello, I recently tried to reproduce these methods and obtained results very close to those in the paper. But I am still a little confused: what method did you use to generate the saliency maps when replicating the CLE method? Thank you so much!

xiangjieSui commented 8 months ago

Thanks for the follow-up. I used the following method to generate saliency maps for CLE: "Static and Space-time Visual Saliency Detection by Self-Resemblance," Hae Jong Seo and Peyman Milanfar.

Since I rarely check my GitHub email, you can contact me at xjsui@foxmail.com for a timely reply. Thanks!
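
The get_salmaps helper used in evaluation.py is likewise not shown in the thread. A generic fixation-map baseline that rasterizes normalized fixations and blurs them with a Gaussian (this is NOT Seo and Milanfar's self-resemblance method, and fixations_to_salmap with its parameters is hypothetical):

import numpy as np
from scipy.ndimage import gaussian_filter

def fixations_to_salmap(fix_norm, size=(720, 1440), sigma_deg=2.0):
    # fix_norm: (N, 2) fixations normalized to [0, 1] as (y, x);
    # size: output map (H, W); sigma_deg: Gaussian width in degrees,
    # converted to pixels assuming the map spans 180 degrees vertically.
    h, w = size
    salmap = np.zeros((h, w))
    ys = np.clip((fix_norm[:, 0] * h).astype(int), 0, h - 1)
    xs = np.clip((fix_norm[:, 1] * w).astype(int), 0, w - 1)
    np.add.at(salmap, (ys, xs), 1.0)  # accumulate fixation counts
    salmap = gaussian_filter(salmap, sigma=sigma_deg / 180.0 * h)
    return salmap / (salmap.max() + 1e-12)  # normalize to [0, 1]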