Open AleHD opened 5 months ago
Hello,
Thank you for your PR.
As I understand it, this PR checks the checkpoint version: if it is 1.3, a different pattern is used to retrieve the shards; otherwise, the existing pattern is used. I suspect this approach will not resolve every case, because checkpoint version 1.3 was added to the main branch only 4 days ago (see commit), whereas the bug was introduced a month and a half ago in this commit. Checking for version 1.2 also seems problematic, since it might break the loading logic for all checkpoints created before that commit.
Timeline: Checkpoint version=1.2 -> bug introduced -> Checkpoint version=1.3
We have already merged a quick fix that uses a script to rename the affected files: PR #151. Could you build on that solution, perhaps by applying it automatically when this pattern is detected instead of requiring users to run the script manually?
I really appreciate your contribution.
Thanks for the feedback. I implemented the suggested changes and tested them locally; everything seems to be working. Let me know if you have any other comments.
This pull request introduces three main additions:
.safetensors suffix incorrectly placed.
serialize.weights.load_weights to serialize.utils.get_path thanks to the addition of the return_all_matches keyword.
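For illustration, the automatic fix requested in the review (renaming shards whose .safetensors suffix is misplaced, without a manual script) could look roughly like the sketch below. It is not the code from this PR or from PR #151: the helper names fixed_name and repair_checkpoint are hypothetical, and the assumed broken pattern (the suffix appearing in the middle of the file name rather than at the end) is inferred only from the description above.

```python
from pathlib import Path
from typing import List, Optional, Tuple

SUFFIX = ".safetensors"


def fixed_name(name: str) -> Optional[str]:
    """Return the corrected file name, or None if no fix is needed.

    Assumption: a broken name contains the suffix somewhere other than
    the end, e.g. "weight.safetensors_rank-0" (hypothetical example).
    """
    if SUFFIX not in name or name.endswith(SUFFIX):
        return None
    # Split at the first occurrence of the suffix and move it to the end.
    head, _, tail = name.partition(SUFFIX)
    return f"{head}{tail}{SUFFIX}"


def repair_checkpoint(root: Path, dry_run: bool = True) -> List[Tuple[Path, Path]]:
    """Find misnamed shard files under root and optionally rename them.

    Returns the (old, new) pairs that were, or would be, renamed.
    """
    renames: List[Tuple[Path, Path]] = []
    # sorted() materializes the listing before any rename happens.
    for path in sorted(root.rglob("*")):
        if not path.is_file():
            continue
        new_name = fixed_name(path.name)
        if new_name is None:
            continue
        target = path.with_name(new_name)
        if target.exists():
            # Never overwrite an existing, correctly named shard.
            continue
        renames.append((path, target))
        if not dry_run:
            path.rename(target)
    return renames
```

Running it with dry_run=True first lists the files that would be renamed, which keeps the detection independent of the checkpoint version and so sidesteps the timeline problem raised in the review.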