google-research / l2p

Learning to Prompt (L2P) for Continual Learning @ CVPR22 and DualPrompt: Complementary Prompting for Rehearsal-free Continual Learning @ ECCV22
https://arxiv.org/pdf/2112.08654.pdf
Apache License 2.0

Transfer prompt parameter during training process. #38

Open kimsekeun opened 1 year ago

kimsekeun commented 1 year ago

Hi authors, thanks for the great work. I have a question about the training process: the prompt indices are selected exclusively based on the task id, and I assume the prompt pool is shared across tasks. Why do we need to transfer the learned prompt parameters to the new indices? I would think the prompts at the new indices get trained on the next task automatically.
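To make the question concrete, here is a tiny sketch of the indexing scheme being described: each task owns a contiguous block of `top_k` slots in the shared pool. The helper name `task_slice` is hypothetical, only for illustration.

```python
# Hypothetical helper: which slots in the shared prompt pool belong to a task,
# assuming each task is assigned top_k consecutive slots.
def task_slice(task_id: int, top_k: int):
    return task_id * top_k, (task_id + 1) * top_k

start, end = task_slice(1, 5)  # task 1 with top_k=5 owns slots 5..9
```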

Relevant part of the code:

```python
# Transfer previously learned prompt params to the new prompt slots.
if config.prompt_pool and config.prompt_pool_param.shared_prompt_pool:
  if task_id > 0:
    prev_start = (task_id - 1) * config.prompt_pool_param.top_k
    prev_end = task_id * config.prompt_pool_param.top_k
    cur_start = prev_end
    cur_end = (task_id + 1) * config.prompt_pool_param.top_k
    if (prev_end > config.prompt_pool_param.pool_size) or (
        cur_end > config.prompt_pool_param.pool_size):
      pass
    else:
      param_dict = state.optimizer.target
      prompt_pool_para = param_dict["prompt_pool"]["prompt"]
      if config.use_prefix_tune_for_e_prompt:
        prompt_pool_para = prompt_pool_para.at[:, :, cur_start:cur_end].set(
            prompt_pool_para[:, :, prev_start:prev_end])
      else:
        prompt_pool_para = prompt_pool_para.at[:, cur_start:cur_end].set(
            prompt_pool_para[:, prev_start:prev_end])
      param_dict = utils.replace_prompt_pool(param_dict, prompt_pool_para)
      state = utils.state_with_new_param(state, param_dict)
```
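To isolate what the transfer step does, here is a minimal, self-contained sketch of the slot-copy using plain `jax.numpy`. The shapes (`pool_size`, `length`, `embed_dim`) and the concrete numbers are illustrative assumptions, not the exact shapes from the L2P codebase; the point is that `.at[...].set(...)` returns a new array in which the current task's slots are initialized from the previous task's learned slots.

```python
# Minimal sketch of the prompt-transfer step, assuming illustrative shapes.
import jax.numpy as jnp

pool_size, length, embed_dim = 10, 8, 16
top_k = 5
prompt = jnp.arange(pool_size * length * embed_dim, dtype=jnp.float32).reshape(
    pool_size, length, embed_dim)

task_id = 1
prev_start = (task_id - 1) * top_k   # slots of the previous task: 0..4
prev_end = task_id * top_k
cur_start = prev_end                  # slots of the current task: 5..9
cur_end = (task_id + 1) * top_k

# JAX arrays are immutable, so .at[...].set(...) builds a new array where
# the current task's slots start from the previous task's learned values.
new_prompt = prompt.at[cur_start:cur_end].set(prompt[prev_start:prev_end])
```

This warm-starts the new task's prompts rather than leaving them at their random initialization, which is the apparent purpose of the transfer the question asks about.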