haoliuhl / ringattention

Transformers with Arbitrarily Large Context
Apache License 2.0

fine-tuning model mismatch - KeyError #13

Closed chenwuperth closed 8 months ago

chenwuperth commented 8 months ago

Thanks for providing the repo. I have a question regarding fine-tuning as mentioned in the paper (Section 5.4).

As the README.md suggests, --load_checkpoint='params::/path/output' is used for fine-tuning from an HF model converted with the hf2jax.py script. However, when scan_layers=True, the layer names (keys) from /path/output do not match those in shard_fns when loading the HF weights. For example, ('transformer', 'h', 'scan_decoder', 'attention', 'wq', 'kernel') from shard_fns does not match the key ('transformer', 'h', '0', 'attention', 'wq', 'kernel') unpacked from /path/output.

This eventually raises the KeyError: ('transformer', 'h', '0', 'attention', 'wq', 'kernel') exception during load_checkpoint.
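For context, below is a rough sketch of the kind of key remapping I assume would be needed to make a per-layer checkpoint match the scanned layout. The flat tuple-keyed structure, the 'scan_decoder' grouping, and the leading stack axis are my assumptions based on the error above, not taken from the repo's actual loader:

```python
# Sketch only: regroup per-layer checkpoint keys into a scanned layout.
# Assumptions (not verified against the repo): params are a flat dict keyed
# by tuples, per-layer keys look like ('transformer', 'h', '<idx>', ...),
# and the scanned model stacks layers along a new leading axis under
# 'scan_decoder'.
from collections import defaultdict

import numpy as np


def stack_layers_for_scan(flat_params):
    """Stack ('transformer', 'h', '<idx>', ...) entries along axis 0
    under ('transformer', 'h', 'scan_decoder', ...)."""
    grouped = defaultdict(dict)  # key suffix -> {layer index: array}
    out = {}
    for key, value in flat_params.items():
        if len(key) > 3 and key[:2] == ('transformer', 'h') and key[2].isdigit():
            grouped[key[3:]][int(key[2])] = value
        else:
            out[key] = value  # non-layer params pass through unchanged
    for suffix, by_layer in grouped.items():
        stacked = np.stack([by_layer[i] for i in sorted(by_layer)], axis=0)
        out[('transformer', 'h', 'scan_decoder') + suffix] = stacked
    return out
```

If something like this already exists in the conversion or loading path, I may simply be missing the right flag.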

Have I missed anything in the fine-tuning configuration, or is there a workaround for this?

Thank you!