xlwangDev / HC-Net

[NeurIPS 2023] Fine-Grained Cross-View Geo-Localization Using a Correlation-Aware Homography Estimator
50 stars 5 forks source link

Code for the corrected labels when testing CCVPE? #4

Closed Sugar1y closed 10 months ago

Sugar1y commented 11 months ago

Could you provide the full dataset.py code you used to re-test CCVPE with the corrected labels?

guanfang12 commented 11 months ago
# ---------------------------------------------------------------------------------
# VIGOR

class VIGORDataset(Dataset):
    """VIGOR cross-view dataset: ground panoramas paired with satellite tiles.

    Satellite tile lists and panorama label files are read from
    ``<root>/<label_root>/<city>/``; images are loaded lazily in
    ``__getitem__``.  The ground-truth pixel location inside the satellite
    tile is recomputed from the GPS coordinates encoded in the image file
    names (via ``get_pixel_tensor``, defined elsewhere), not taken from the
    deltas stored in the label file.

    Args:
        root: dataset root containing ``<city>/panorama``,
            ``<city>/satellite`` and the ``label_root`` split directory.
        label_root: split directory under ``root``. Defaults to the
            corrected splits ('splits__corrected').
        split: 'samearea' or 'crossarea'.
        train: selects the train/test label file ('samearea') or the
            train/test cities ('crossarea').
        transform: optional ``(ground_transform, satellite_transform)`` pair;
            each must map a PIL image to a CHW tensor. When None, images are
            converted to float tensors in [0, 1] without resizing.
        pos_only: if True always use the positive tile (label index 0),
            otherwise sample among the positive and semi-positive tiles.
        ori_noise: maximum orientation noise in degrees; values >= 180
            sample a uniformly random heading.

    Raises:
        ValueError: if ``split`` is not 'samearea' or 'crossarea'.
    """

    def __init__(self, root, label_root='splits__corrected', split='samearea', train=True, transform=None, pos_only=True, ori_noise=180):
        self.root = root
        self.label_root = label_root
        self.split = split
        self.train = train
        self.pos_only = pos_only
        self.ori_noise = ori_noise

        if transform is not None:
            self.grdimage_transform = transform[0]
            self.satimage_transform = transform[1]
        else:
            # __getitem__ relies on tensor ops (torch.roll, .size()), so a
            # PIL -> tensor conversion is needed even without user transforms.
            self.grdimage_transform = self._pil_to_tensor
            self.satimage_transform = self._pil_to_tensor

        if split == 'samearea':
            self.city_list = ['NewYork', 'Seattle', 'SanFrancisco', 'Chicago']
        elif split == 'crossarea':
            self.city_list = ['NewYork', 'Seattle'] if train else ['SanFrancisco', 'Chicago']
        else:
            raise ValueError("split must be 'samearea' or 'crossarea', got {!r}".format(split))

        self._load_satellite_list()
        self._load_labels()

    @staticmethod
    def _pil_to_tensor(img):
        """Convert an HWC image (PIL or array) to a CHW float32 tensor in [0, 1]."""
        return torch.from_numpy(np.asarray(img, dtype=np.float32) / 255.0).permute(2, 0, 1).contiguous()

    def _load_satellite_list(self):
        """Read each city's satellite_list.txt into sat_list and the name -> index map."""
        self.sat_list = []
        self.sat_index_dict = {}
        idx = 0
        for city in self.city_list:
            sat_list_fname = os.path.join(self.root, self.label_root, city, 'satellite_list.txt')
            with open(sat_list_fname, 'r') as file:
                for line in file:
                    sat_name = line.strip()
                    if not sat_name:  # tolerate blank lines
                        continue
                    self.sat_list.append(os.path.join(self.root, city, 'satellite', sat_name))
                    self.sat_index_dict[sat_name] = idx
                    idx += 1
            print('InputData::__init__: load', sat_list_fname, idx)
        self.sat_list = np.array(self.sat_list)
        self.sat_data_size = len(self.sat_list)
        print('Sat loaded, data size:{}'.format(self.sat_data_size))

    def _load_labels(self):
        """Read each city's panorama label file into grd_list, label, delta and sat_cover_dict."""
        if self.split == 'samearea':
            label_file = 'same_area_balanced_train.txt' if self.train else 'same_area_balanced_test.txt'
        else:  # 'crossarea' (validated in __init__)
            label_file = 'pano_label_balanced.txt'

        self.grd_list = []
        self.label = []
        self.sat_cover_dict = {}
        self.delta = []
        idx = 0
        for city in self.city_list:
            label_fname = os.path.join(self.root, self.label_root, city, label_file)
            with open(label_fname, 'r') as file:
                for line in file:
                    # Line layout implied by the indices below:
                    # <pano> followed by 4 x (<satellite name> <delta_0> <delta_1>).
                    data = line.split()
                    if not data:  # tolerate blank lines
                        continue
                    label = np.array([self.sat_index_dict[data[i]] for i in (1, 4, 7, 10)]).astype(int)
                    delta = np.array([data[2:4], data[5:7], data[8:10], data[11:13]]).astype(float)
                    self.grd_list.append(os.path.join(self.root, city, 'panorama', data[0]))
                    self.label.append(label)
                    self.delta.append(delta)
                    # Group panorama indices by their positive (first) satellite tile.
                    self.sat_cover_dict.setdefault(label[0], []).append(idx)
                    idx += 1
            print('InputData::__init__: load ', label_fname, idx)
        self.data_size = len(self.grd_list)
        print('Grd loaded, data size:{}'.format(self.data_size))
        self.label = np.array(self.label)
        self.delta = np.array(self.delta)

    def __len__(self):
        return self.data_size

    def __getitem__(self, idx):
        """Return one sample.

        Returns:
            grd: ground panorama tensor, circularly shifted along dim 2 by the
                sampled rotation.
            sat: satellite tile tensor (C, H, W).
            gt: (1, H, W) Gaussian heatmap (sigma = 4 px) around the
                ground-truth location.
            gt_with_ori: (20, H, W) heatmap with the Gaussian split between
                the two neighbouring 18-degree orientation bins.
            orientation: (2, H, W) maps filled with cos / sin of the angle.
            city: city name taken from the panorama path.
            orientation_angle: sampled heading in degrees.
        """
        grd_path = self.grd_list[idx]
        # Panorama GPS is encoded in the file name; [:-5] drops a trailing
        # ',.jpg' (assumed '<id>,<lat>,<lon>,.jpg' naming).  Parsed outside the
        # try so it is defined even when the image itself is unreadable.
        pano_gps = np.array(grd_path[:-5].split(',')[-2:]).astype(float)
        try:
            with PIL.Image.open(grd_path) as img:
                grd = img.convert('RGB')
        except Exception:  # best-effort: any read/decode failure -> blank image
            print('unreadable image')
            # NOTE(review): PIL sizes are (width, height), so this is 320 wide x
            # 640 tall — confirm that orientation is intended.
            grd = PIL.Image.new('RGB', (320, 640))
        grd = self.grdimage_transform(grd)

        # Sample a rotation as a fraction of a full turn.
        if self.ori_noise >= 180:
            rotation = np.random.uniform(low=0.0, high=1.0)
        else:
            rotation_range = self.ori_noise / 360
            rotation = np.random.uniform(low=-rotation_range, high=rotation_range)
            if rotation < 0:
                rotation = 1 + rotation  # wrap negative fractions to the end of the circle
        shift = torch.round(torch.as_tensor(rotation) * grd.size()[2]).int().item()
        grd = torch.roll(grd, shift, dims=2)

        orientation_angle = rotation * 360  # 0 means heading North, counter-clockwise increasing

        # Choose which satellite tile to pair with the panorama.
        if self.pos_only:
            pos_index = 0
        else:
            # Sample positive / semi-positive tiles, rejecting those whose stored
            # delta is >= 320 px from the centre (GT outside the image).
            while True:
                pos_index = random.randint(0, 3)
                row_delta, col_delta = self.delta[idx, pos_index]
                if np.abs(col_delta) < 320 and np.abs(row_delta) < 320:
                    break
        sat_path = self.sat_list[self.label[idx][pos_index]]
        # Satellite GPS is encoded as the last two '_'-separated fields before '.png'.
        sat_gps = np.array(sat_path[:-4].split('_')[-2:]).astype(float)
        with PIL.Image.open(sat_path) as img:
            sat = img.convert('RGB')

        pano_gps = torch.from_numpy(pano_gps).unsqueeze(0)  # (1, 2)
        sat_gps = torch.from_numpy(sat_gps).unsqueeze(0)  # (1, 2)
        zoom = 20
        # get_pixel_tensor is defined elsewhere; presumably it returns the
        # panorama's (col, row) pixel position in the raw satellite tile.
        pixel = get_pixel_tensor(sat_gps[:, 0], sat_gps[:, 1], pano_gps[:, 0], pano_gps[:, 1], zoom)
        col_px, row_px = pixel[0], pixel[1]

        width_raw, height_raw = sat.size
        sat = self.satimage_transform(sat)
        _, height, width = sat.size()

        # Offsets chosen so the Gaussian built from the meshgrid below peaks
        # (approximately) at (row_px, col_px), rescaled to the transformed size.
        col_offset = width_raw / 2 - col_px.item()
        row_offset = row_px.item() - height_raw / 2
        row_offset = np.round(row_offset / height_raw * height)
        col_offset = np.round(col_offset / width_raw * width)

        # Gaussian ground-truth heatmap around the GT location.
        x, y = np.meshgrid(
            np.linspace(-width / 2 + col_offset, width / 2 + col_offset, width),
            np.linspace(-height / 2 - row_offset, height / 2 - row_offset, height),
        )
        d = np.sqrt(x * x + y * y)
        sigma, mu = 4, 0.0
        gaussian = np.exp(-((d - mu) ** 2 / (2.0 * sigma ** 2)))

        gt = np.zeros([1, height, width], dtype=np.float32)
        gt[0, :, :] = gaussian
        gt = torch.tensor(gt)

        # 20 orientation bins of 18 degrees each; the Gaussian is split between
        # two neighbouring bins by the fractional position within the bin.
        # Bin numbers decrease as the angle grows: (20 - index) % 20 maps index 0
        # to bin 0, and (19 - index) % 20 maps it to bin 19.
        gt_with_ori = np.zeros([20, height, width], dtype=np.float32)
        index = int(orientation_angle // 18)
        ratio = (orientation_angle % 18) / 18
        gt_with_ori[(20 - index) % 20, :, :] = gaussian * (1 - ratio)
        gt_with_ori[(19 - index) % 20, :, :] = gaussian * ratio
        gt_with_ori = torch.tensor(gt_with_ori)

        orientation = torch.full([2, height, width], np.cos(orientation_angle * np.pi / 180))
        orientation[1, :, :] = np.sin(orientation_angle * np.pi / 180)

        # grd_list entries are built as <root>/<city>/panorama/<file> in _load_labels.
        city = os.path.basename(os.path.dirname(os.path.dirname(grd_path)))

        return grd, sat, gt, gt_with_ori, orientation, city, orientation_angle