lorenmt / auto-lambda

The Implementation of "Auto-Lambda: Disentangling Dynamic Task Relationships" [TMLR 2022].
https://shikun.io/projects/auto-lambda
Other
129 stars 17 forks source link

想咨询train_tasks和pri_tasks代码含义 #12

Closed HuangZR-stu closed 1 month ago

HuangZR-stu commented 1 month ago

train_tasks = create_task_flags('all', opt.dataset, with_noise=False)

pri_tasks = create_task_flags(opt.task, opt.dataset, with_noise=False)

def create_task_flags(task, dataset, with_noise=False): """ Record task and its prediction dimension. Noise prediction is only applied in auxiliary learning. """ nyu_tasks = {'seg': 13, 'depth': 1, 'normal': 3} cityscapes_tasks = {'seg': 19, 'part_seg': 10, 'disp': 1}

tasks = {}
if task != 'all':
    if dataset == 'nyuv2':
        tasks[task] = nyu_tasks[task]
    elif dataset == 'cityscapes':
        tasks[task] = cityscapes_tasks[task]
else:
    if dataset == 'nyuv2':
        tasks = nyu_tasks
    elif dataset == 'cityscapes':
        tasks = cityscapes_tasks

if with_noise:
    tasks['noise'] = 1
return tasks

  # define weighting for primary tasks (with binary weights)
  pri_weights = []
  for t in self.train_tasks:
      if t in self.pri_tasks:
          pri_weights += [1.0]
      else:
          pri_weights += [0.0]

在训练时期的代码,为什么要区分train_tasks和pri_tasks,可以分享一下这部分的代码逻辑吗,感谢大哥!

lorenmt commented 1 month ago

Training tasks 是所有训练的任务,primary tasks 是你关注performance的任务。 primary tasks are the subset of training tasks --> if training tasks == primary tasks, we are doing multi-task learning; else we are doing auxiliary learning. 所有不在primary tasks的training tasks 只是为了辅助 primary tasks。

HuangZR-stu commented 1 month ago

嗯嗯,我今天进行了复现,但是发现辅助任务的权重设置为0.0,当时看有点害怕辅助损失一点也没用上。后续跟了几次以meta_weights计算的修正,略微懂了一点。可惜原文没有伪代码可能看得不是很明白。感谢您的回复。