Thanks for providing the repo. I have a question regarding fine-tuning as mentioned in the paper (Section 5.4).
As the README.md suggests, `--load_checkpoint='params::/path/output'` is used for fine-tuning from an HF model converted with the hf2jax.py script. However, when `scan_layers=True`, the layer names (keys) from `/path/output` do not match those in `shard_fns` when loading the HF weights. For example,
`('transformer', 'h', 'scan_decoder', 'attention', 'wq', 'kernel')` from `shard_fns` does not match the key `('transformer', 'h', '0', 'attention', 'wq', 'kernel')` unpacked from `/path/output`.
This eventually raises a `KeyError: ('transformer', 'h', '0', 'attention', 'wq', 'kernel')` during `load_checkpoint`.
Have I missed anything in the fine-tuning configuration, or is there a workaround for this?
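
For reference, here is the workaround direction I've been considering: restacking the per-layer checkpoint into the scanned layout before it reaches `shard_fns`. This is only a minimal sketch based on the key structure in the error above; the `stack_for_scan` name, the `scan_decoder` key, and stacking along axis 0 are my assumptions rather than anything confirmed by the repo (the actual scan axis may differ depending on the model config):

```python
# Rough workaround sketch (my own code, not from the repo): restack the
# per-layer weights produced by hf2jax.py into the stacked layout that
# scan_layers=True appears to expect. The 'scan_decoder' key and the
# leading stack axis (axis=0) are guesses based on the KeyError above.
import jax.numpy as jnp
from flax.traverse_util import flatten_dict, unflatten_dict

def stack_for_scan(params, scan_key="scan_decoder"):
    flat = flatten_dict(params)
    per_layer = {}  # maps a key suffix to {layer_index: array}
    out = {}
    for path, value in flat.items():
        # Collect keys of the form ('transformer', 'h', '<idx>', ...)
        if len(path) > 2 and path[:2] == ("transformer", "h") and path[2].isdigit():
            per_layer.setdefault(path[3:], {})[int(path[2])] = value
        else:
            out[path] = value  # non-layer params pass through unchanged
    for suffix, layers in per_layer.items():
        # Stack layers 0..N-1 along a new leading axis for scan.
        stacked = jnp.stack([layers[i] for i in sorted(layers)], axis=0)
        out[("transformer", "h", scan_key) + suffix] = stacked
    return unflatten_dict(out)
```

The idea would be to apply this to the params unpacked from `/path/output` before they are matched against `shard_fns`, but I'm not sure whether this is the intended approach.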
Thank you!