Li-ZK / CLDA-2022

Confident Learning-Based Domain Adaptation for Hyperspectral Image Classification(DOI: 10.1109/TGRS.2022.3166817)
25 stars 4 forks source link

关于伪标签的获取 #4

Open lzzlhh opened 4 weeks ago

lzzlhh commented 4 weeks ago

李老师,您好 拜读了您的文章,我有个问题想请教一下 您的代码中是if (ep >= train_num and ep < num_epoch) and ep % 20 == 0: 才会获取fake_label并clean data,用confident 目标域的data和fake_label去训练。但是在训练过程中您的代码中
if ep >= train_epoch: (data_s, label_s), (data_t, fake_label_t) = data
fake_label_t = Variable(fake_label_t).cuda()
您的train_num =train_epoch ,您代码中的设置都为20 如果ep=21 那么这时候就不会获取fake_label,而您 (data_s, label_s), (data_t, fake_label_t) = data 这行代码中获取的fake_label_t 不就是目标域的真实标签吗? 可能理解的不对,请您赐教,万分感谢

fang-zhuoqun commented 4 weeks ago

你好,fake_label_t 来自于 train() 函数的参数 data_loader_t,在 CLDA_UP2PC.py 主程序中的361行,我们用以下代码调用 train() 函数: clean_datas, clean_labels, class_weights = clean_sampling_epoch(fake_label, probs) target_datasets = TensorDataset(torch.tensor(clean_datas), torch.tensor(clean_labels)) train_loader_t = DataLoader(target_datasets, batch_size=BATCH_SIZE, shuffle=True, num_workers=0,drop_last=True) train(ep, train_loader_s, train_loader_t, train_num, class_weights) 可见,这里的实参 train_loader_t 就是清洗后的伪标签,而非真实标签。因此,你所说的变量 fake_label_t 也全都来自于伪标签。

lzzlhh commented 4 weeks ago

非常感谢您的回复,以下这三行代码的前提是 if (ep >= train_num and ep < num_epoch) and ep % 20 == 0: clean_datas, clean_labels, class_weights = clean_sampling_epoch(fake_label, probs) target_datasets = TensorDataset(torch.tensor(clean_datas), torch.tensor(clean_labels)) train_loader_t = DataLoader(target_datasets, batch_size=BATCH_SIZE, shuffle=True, num_workers=0,drop_last=True) 但是如果ep%20≠0时,此时就不会生成清洗后的train_loader_t,而是使用之前的未经过清洗的train_loader_t,这时 train() 函数的参数 data_loader_t使用的就是之前的未经过清洗的train_loader_t,但是后续代码中依然使用 if ep >= train_epoch: (data_s, label_s), (data_t, fake_label_t) = data ,这里的fake_label_t不就是来源于未经过清洗的train_loader_t,他这不就是真实标签吗?希望您能够不吝赐教