Hi authors, thanks for the great work.
I have a question about the training process. As far as I can tell, prompts are selected from the pool exclusively by task id.
I assume the prompt pool is shared across tasks.
Why do we need to transfer the previously learned prompts to the new indices? I would expect the new indices to be trained on the next task automatically.
The relevant part of the code:
    # Transfer previous learned prompt params to the new prompt
    if config.prompt_pool and config.prompt_pool_param.shared_prompt_pool:
      if task_id > 0:
        prev_start = (task_id - 1) * config.prompt_pool_param.top_k
        prev_end = task_id * config.prompt_pool_param.top_k
        cur_start = prev_end
        cur_end = (task_id + 1) * config.prompt_pool_param.top_k
        if (prev_end > config.prompt_pool_param.pool_size) or (
            cur_end > config.prompt_pool_param.pool_size):
          pass
        else:
          param_dict = state.optimizer.target
          prompt_pool_para = param_dict["prompt_pool"]["prompt"]
          if config.use_prefix_tune_for_e_prompt:
            prompt_pool_para = prompt_pool_para.at[:, :, cur_start:cur_end].set(
                prompt_pool_para[:, :, prev_start:prev_end])
          else:
            prompt_pool_para = prompt_pool_para.at[:, cur_start:cur_end].set(
                prompt_pool_para[:, prev_start:prev_end])
          param_dict, _ = utils.replace_prompt_pool(param_dict, prompt_pool_para)
          state = utils.state_with_new_param(state, param_dict)
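
To check my reading of this snippet, here is a minimal toy sketch of what I understand it to do. It is my own illustration, not code from the repository: `transfer_prompts` is a hypothetical helper, and the toy pool is a 2-D array sliced along axis 0, whereas the real prompt tensor is sliced along axis 1 or 2.

```python
import jax.numpy as jnp


def transfer_prompts(pool, task_id, top_k):
    """Copy the previous task's prompt slots into the current task's slots.

    Hypothetical helper for illustration only. `pool` is a toy array of
    shape (pool_size, prompt_dim), indexed along axis 0.
    """
    pool_size = pool.shape[0]
    prev_start = (task_id - 1) * top_k
    prev_end = task_id * top_k
    cur_start = prev_end
    cur_end = (task_id + 1) * top_k
    # Skip the first task and any task whose slots fall outside the pool.
    # cur_end > prev_end, so checking cur_end also covers prev_end.
    if task_id == 0 or cur_end > pool_size:
        return pool
    # JAX arrays are immutable: .at[...].set(...) returns an updated copy.
    return pool.at[cur_start:cur_end].set(pool[prev_start:prev_end])


# Toy pool with 6 slots of dimension 2, and top_k = 2 slots per task.
pool = jnp.arange(12.0).reshape(6, 2)
new_pool = transfer_prompts(pool, task_id=1, top_k=2)
```

If this reading is right, slots 2 and 3 (task 1) start from the values learned in slots 0 and 1 (task 0) rather than from their own initialization, and the remaining slots are untouched. That is the step I am asking about.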