lucidrains / audiolm-pytorch

Implementation of AudioLM, a SOTA Language Modeling Approach to Audio Generation out of Google Research, in Pytorch
MIT License
2.36k stars 255 forks

Error when training CoarseTransformer #111

Closed Ted-developer closed 1 year ago

Ted-developer commented 1 year ago

In my CoarseTransformer training script, I use soundstream.load('./output/soundstream.pt') to load the model that I first trained with SoundStreamTrainer.

I got an error:

    /home/work/python/lib/python3.8/site-packages/torch/nn/modules/module.py:1604 in load_state_dict

      1601                         ', '.join('"{}"'.format(k) for k in missing_keys)))
      1602
      1603         if len(error_msgs) > 0:
    ❱ 1604             raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
      1605                                self.__class__.__name__, "\n\t".join(error_msgs)))
      1606         return _IncompatibleKeys(missing_keys, unexpected_keys)
      1607

    RuntimeError: Error(s) in loading state_dict for SoundStream:
        Unexpected key(s) in state_dict: "encoder_attn.1.attn.q_scale", "encoder_attn.1.attn.k_scale", "encoder_attn.1.attn.norm.weight", "encoder_attn.1.attn.norm.bias", "encoder_attn.1.attn.to_qkv.weight", "encoder_attn.1.attn.attn_fn.rel_pos.inv_freq", "encoder_attn.1.attn.to_out.weight", "encoder_attn.1.ff.0.weight", "encoder_attn.1.ff.0.bias", "encoder_attn.1.ff.1.weight", "encoder_attn.1.ff.4.weight", "decoder_attn.1.attn.q_scale", "decoder_attn.1.attn.k_scale", "decoder_attn.1.attn.norm.weight", "decoder_attn.1.attn.norm.bias", "decoder_attn.1.attn.to_qkv.weight", "decoder_attn.1.attn.attn_fn.rel_pos.inv_freq", "decoder_attn.1.attn.to_out.weight", "decoder_attn.1.ff.0.weight", "decoder_attn.1.ff.0.bias", "decoder_attn.1.ff.1.weight", "decoder_attn.1.ff.4.weight".

I have looked at the save() and load() methods but can't find the problem. Please help me, thanks very much.

lucidrains commented 1 year ago

@Ted-developer

hey, so what must have happened is that you trained soundstream on an older version (without a new attention stabilizing trick) and then updated to a newer version
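A mismatch like this can be confirmed by comparing the key names in the checkpoint with the key names the current model expects. The sketch below is plain Python and is not part of audiolm-pytorch: `diff_state_dict_keys` and `count_by_prefix` are hypothetical helpers, and the key lists are illustrative (three keys are copied from the error above, while the shared `encoder_attn.0` key is made up for the example). With a real model, the first argument would come from `model.state_dict().keys()`. Where the weights sit inside the saved file depends on how it was written, so that part is left out.

```python
from collections import Counter


def diff_state_dict_keys(model_keys, checkpoint_keys):
    """Return (missing, unexpected) key lists, mirroring PyTorch's terminology.

    missing    = keys the model expects but the checkpoint lacks
    unexpected = keys the checkpoint has but the model does not expect
    """
    model_keys, checkpoint_keys = set(model_keys), set(checkpoint_keys)
    missing = sorted(model_keys - checkpoint_keys)
    unexpected = sorted(checkpoint_keys - model_keys)
    return missing, unexpected


def count_by_prefix(keys, depth=2):
    """Count keys by their first `depth` dotted components."""
    return Counter(".".join(key.split(".")[:depth]) for key in keys)


# Illustrative data only.
checkpoint_keys = [
    "encoder_attn.0.attn.to_qkv.weight",  # hypothetical key present on both sides
    "encoder_attn.1.attn.to_qkv.weight",  # copied from the error message
    "encoder_attn.1.attn.to_out.weight",  # copied from the error message
    "decoder_attn.1.attn.to_qkv.weight",  # copied from the error message
]
model_keys = ["encoder_attn.0.attn.to_qkv.weight"]  # hypothetical

missing, unexpected = diff_state_dict_keys(model_keys, checkpoint_keys)
print("missing:", missing)
print("unexpected by prefix:", dict(count_by_prefix(unexpected)))
```

Grouping by prefix makes the pattern in the error easier to see: every unexpected key there sits under `encoder_attn.1` or `decoder_attn.1`, which suggests the checkpoint was written by a model with a different attention layout than the one being constructed now.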

could you try updating to the newest version of audiolm-pytorch and then doing soundstream.load('./output/soundstream.pt', strict = False), and then resume training for another epoch and see if you can salvage the old weights
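For reference, here is a minimal sketch of what non-strict loading does in plain PyTorch. The traceback shows the error being raised from `load_state_dict`, so this assumes the `strict` argument of `soundstream.load` behaves like the `strict` argument there. `CheckpointNet` and `CurrentNet` are toy stand-ins invented for this example, not the real SoundStream classes.

```python
import torch
from torch import nn


class CheckpointNet(nn.Module):
    """Toy stand-in for the model version that wrote the checkpoint."""

    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)
        self.extra = nn.Linear(4, 4)  # layer the current model does not have


class CurrentNet(nn.Module):
    """Toy stand-in for the model version doing the loading."""

    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)


checkpoint = CheckpointNet().state_dict()
model = CurrentNet()

# Default strict=True: any unexpected or missing key raises RuntimeError.
try:
    model.load_state_dict(checkpoint)
    strict_failed = False
except RuntimeError as err:
    strict_failed = True
    print("strict load failed:", str(err).splitlines()[0])

# strict=False: matching keys are copied, mismatched key names are reported.
result = model.load_state_dict(checkpoint, strict=False)
print("unexpected:", result.unexpected_keys)
print("missing:", result.missing_keys)
```

Two caveats follow from how `load_state_dict` works. Parameters listed under `missing_keys` keep their fresh initialization, which is why further training is needed afterwards. And `strict=False` only relaxes key-name mismatches: a tensor whose shape differs from the model's still raises an error.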

otherwise, you'll have to load the specific version of audiolm-pytorch at which you trained your soundstream. however there's a bug with the coarse transformer in older versions, so i recommend you just train a new soundstream