tianyang-x / Mixture-of-Domain-Adapters

Codebase for ACL 2023 paper "Mixture-of-Domain-Adapters: Decoupling and Injecting Domain Knowledge to Pre-trained Language Models' Memories"
MIT License

Stage 2 does not seem to load the Stage 1 weights: the keys in the state_dict no longer match #5

Closed: JachinLin2022 closed this issue 10 months ago

JachinLin2022 commented 1 year ago

def save_knowledge_adapter(self, f):
    if hasattr(self, "roberta"):
        encoder = self.roberta.encoder
    else:
        encoder = self.encoder
    ckpt = {}  # renamed from `dict` to avoid shadowing the builtin
    for i, layer in enumerate(encoder.layer):
        if layer.output.kas is not None:
            # One entry per (layer i, adapter j), keyed '{i}-{j}-adapter'
            for j, adapter in enumerate(layer.output.kas):
                ckpt[str(i) + '-' + str(j) + '-adapter'] = adapter.state_dict()
        if layer.output.gating is not None:
            ckpt[str(i) + '-gating'] = layer.output.gating.state_dict()
    torch.save(ckpt, f)

def load_knowledge_adapter(self, f_list, one_f=None):
    if hasattr(self, "roberta"):
        encoder = self.roberta.encoder
    else:
        encoder = self.encoder
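    # f_list path: one file per adapter, looked up with '{i}-adapter' keys (no adapter
    # index), which never match the '{i}-{j}-adapter' keys written by save_knowledge_adapter.
    for n_ka, f in enumerate(f_list):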
        checkpoint = torch.load(f, map_location='cuda:0')
        for i, layer in enumerate(encoder.layer):
            if str(i) + '-adapter' in checkpoint:
                layer.output.kas[n_ka].load_state_dict(checkpoint[str(i) + '-adapter'])
            # if str(i) + '-attn' in checkpoint:
                # layer.output.attn.load_state_dict(checkpoint[str(i) + '-attn'])
            if str(i) + '-gating' in checkpoint and len(f_list) == 1 and layer.output.gating is not None:
                layer.output.gating.load_state_dict(checkpoint[str(i) + '-gating'])
    # one_f path: a single file with '{i}-{n_ka}-adapter' and '{i}-gating' keys.
    if one_f is not None:
        checkpoint = torch.load(one_f, map_location='cuda:0')
        for i, layer in enumerate(encoder.layer):
            for n_ka in range(self.ka_list):
                if str(i) + '-' + str(n_ka) + '-adapter' in checkpoint:
                    layer.output.kas[n_ka].load_state_dict(
                        checkpoint[str(i) + '-' + str(n_ka) + '-adapter'])
            if str(i) + '-gating' in checkpoint:
                layer.output.gating.load_state_dict(
                    checkpoint[str(i) + '-gating'])
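The mismatch can be checked without a model. The sketch below assumes every layer has the same number of knowledge adapters and a gating module. It lists the keys `save_knowledge_adapter` writes and the keys the `f_list` branch of `load_knowledge_adapter` looks up, then reports the ones that are never read. The function names and the 12-layer size are hypothetical.

```python
def saved_keys(num_layers: int, num_adapters: int) -> set[str]:
    """Keys written by save_knowledge_adapter, assuming every layer has
    `num_adapters` knowledge adapters and a gating module (hypothetical helper)."""
    keys = set()
    for i in range(num_layers):
        for j in range(num_adapters):
            keys.add(f"{i}-{j}-adapter")
        keys.add(f"{i}-gating")
    return keys


def keys_read_by_f_list_branch(num_layers: int) -> set[str]:
    """Keys the f_list branch of load_knowledge_adapter looks up
    (gating is only loaded when a single file is passed)."""
    keys = set()
    for i in range(num_layers):
        keys.add(f"{i}-adapter")
        keys.add(f"{i}-gating")
    return keys


def unused_keys(saved: set[str], read: set[str]) -> set[str]:
    """Checkpoint keys the loader never reads, i.e. silently skipped weights."""
    return saved - read


# Hypothetical sizes: 12 encoder layers, one knowledge adapter per layer.
skipped = unused_keys(saved_keys(12, 1), keys_read_by_f_list_branch(12))
print(len(skipped))  # 12: every adapter key is skipped; only gating keys match
```

Applied to the key set of a real checkpoint returned by `torch.load`, the same set difference would turn a silent skip into a visible one.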

tianyang-x commented 1 year ago

Hi @JachinLin2022,

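Thanks for the feedback. In our early experiments we used the `i-adapter` key format (keys like `{i}-adapter`), which is what the `f_list` path of `load_knowledge_adapter` reads. Later we changed the format written by `save_knowledge_adapter` and added the `one_f` parameter to `load_knowledge_adapter` to handle the new format. We used both formats during our experiments, but when we cleaned up the code for release, the `one_f` parameter ended up never actually being passed. We will fix this as soon as possible.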
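One way to reconcile the two formats is to rename legacy `{i}-adapter` keys to the newer `{i}-{j}-adapter` layout before loading, so a single path such as the `one_f` branch can consume both. This is a hypothetical helper, not the change made in the fixing commit. Strings stand in for the per-adapter `state_dict`s so the example runs without PyTorch, and `{i}-gating` keys pass through unchanged.

```python
import re

# Matches legacy keys such as '3-adapter' but not new-format keys such as '3-0-adapter'.
LEGACY_ADAPTER_KEY = re.compile(r"(\d+)-adapter")


def to_new_format(checkpoint: dict, adapter_index: int) -> dict:
    """Rename legacy '{i}-adapter' keys to '{i}-{adapter_index}-adapter'.

    Hypothetical helper: keys already in the new format, and '{i}-gating'
    keys, are copied unchanged.
    """
    converted = {}
    for key, value in checkpoint.items():
        match = LEGACY_ADAPTER_KEY.fullmatch(key)
        if match is not None:
            layer = match.group(1)
            converted[f"{layer}-{adapter_index}-adapter"] = value
        else:
            converted[key] = value
    return converted


# Hypothetical: two legacy Stage 1 files, one per knowledge adapter.
legacy_files = [{"0-adapter": "A0", "0-gating": "G0"}, {"0-adapter": "B0"}]
merged = {}
for n_ka, ckpt in enumerate(legacy_files):
    merged.update(to_new_format(ckpt, n_ka))
print(merged)  # {'0-0-adapter': 'A0', '0-gating': 'G0', '0-1-adapter': 'B0'}
```

In practice the merged dict would hold real `state_dict`s and could be written once with `torch.save`, giving a single file in the layout the `one_f` path expects.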

tianyang-x commented 1 year ago

Hi @JachinLin2022,

The commit I just pushed should fix this issue.