Den4ikAI closed this issue 1 year ago
Sadly, right now you can't save a tensor_parallel model as a normal model checkpoint.
I'm currently working on a PR that will gather the shards automatically when state_dict is called on a tensor_parallel model. That would make tensor_parallel checkpoints indistinguishable from normal ones.
It'll be ready in a few days!
Starting with tensor_parallel v1.1.0, you can save tensor parallel models using conventional methods like state_dict().
If you already have a state dict of a model saved with a previous version, you can convert it into a normal state dict with the following steps:
0) Update to tensor_parallel>=1.1.0
1) Create a tensor parallel model.
2) Load the sharded state dict
3) Call .state_dict() to construct a normal state dict
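The conversion steps above can be sketched roughly as follows. This is only an illustration, assuming tensor_parallel v1.1.0 or later; the helper name convert_to_normal_state_dict, the model name, the device list and both file paths are placeholders I made up, not something from the library. I'd also assume the tensor parallel model should be created on the same number of devices as when the old checkpoint was saved.

```python
def convert_to_normal_state_dict(tp_model, sharded_state_dict):
    """Hypothetical helper covering steps 2 and 3 of the list above."""
    # Step 2: load the sharded state dict saved by an older version.
    tp_model.load_state_dict(sharded_state_dict)
    # Step 3: with tensor_parallel>=1.1.0, state_dict() is expected to
    # return a normal (gathered) state dict.
    return tp_model.state_dict()


def main():
    # Third-party imports stay local so the helper above can be imported
    # on a machine without GPUs or these packages installed.
    import torch
    import transformers
    import tensor_parallel as tp  # step 0: assumes tensor_parallel>=1.1.0

    # Step 1: create a tensor parallel model.
    # Model name and devices are placeholders; use your own.
    model = transformers.AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m")
    model = tp.tensor_parallel(model, ["cuda:0", "cuda:1"])

    # Placeholder paths for the old sharded checkpoint and the converted one.
    sharded = torch.load("sharded_state_dict.pt", map_location="cpu")
    normal = convert_to_normal_state_dict(model, sharded)
    torch.save(normal, "normal_state_dict.pt")
```

Calling main() on a multi-GPU machine should leave you with a state dict that can be loaded into the plain, non tensor parallel model with load_state_dict.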
Thanks
Hi, I trained BLOOM using this library, but how do I load it for inference? The architecture changed from BloomForCausalLM to TensorParalelForPretraining. @BlackSamorez