BlackSamorez / tensor_parallel

Automatically split your PyTorch models on multiple GPUs for training & inference
MIT License

How to use trained models? #48

Closed: Den4ikAI closed this issue 1 year ago

Den4ikAI commented 1 year ago

Hi, I trained BLOOM using this library, but how do I load it for inference? The saved architecture changed from BloomForCausalLM to TensorParallelPreTrainedModel. @BlackSamorez

BlackSamorez commented 1 year ago

Sadly, you currently can't save a tensor_parallel model as a normal model checkpoint. I'm working on a PR that will automatically gather the shards when state_dict() is called on a tensor_parallel model. That would make tensor_parallel checkpoints indistinguishable from normal ones. It'll be ready in a few days!

BlackSamorez commented 1 year ago

Starting with tensor_parallel v1.1.0, you can save tensor parallel models using conventional methods like state_dict(). If you already have a sharded state dict saved with an earlier version, you can convert it into a normal state dict with the following steps:

0) Update to tensor_parallel>=1.1.0.
1) Create a tensor parallel model.
2) Load the sharded state dict into it.
3) Call .state_dict() to construct a normal state dict.
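Conceptually, step 3 reassembles each parameter from its per-device pieces. Below is a minimal plain-Python illustration of that idea. It is not tensor_parallel's implementation: the helper name `gather_shards`, the `split_dims` mapping, and the assumption that every shard's state dict uses the original parameter names are all hypothetical. Weights are plain nested lists (2-D only) rather than torch tensors so the sketch has no dependencies.

```python
from typing import Dict, List, Optional

Matrix = List[List[float]]


def gather_shards(
    shard_state_dicts: List[Dict[str, Matrix]],
    split_dims: Dict[str, Optional[int]],
) -> Dict[str, Matrix]:
    """Hypothetical helper (not part of tensor_parallel's API).

    Merges per-device shards of 2-D weights into one full state dict.
    split_dims maps each parameter name to the dimension it was split
    along (0 = rows, 1 = columns), or None if every device holds a full copy.
    """
    full: Dict[str, Matrix] = {}
    for name, dim in split_dims.items():
        shards = [sd[name] for sd in shard_state_dicts]
        if dim is None:
            # Replicated parameter: all copies are identical, keep the first one.
            full[name] = [row[:] for row in shards[0]]
        elif dim == 0:
            # Row-split: stack the rows of each shard one after another.
            full[name] = [row[:] for shard in shards for row in shard]
        elif dim == 1:
            # Column-split: join the i-th row of every shard side by side.
            full[name] = [
                [x for shard in shards for x in shard[i]]
                for i in range(len(shards[0]))
            ]
        else:
            raise ValueError(f"unsupported split dim {dim} for {name!r}")
    return full


# Two devices: a 2x4 weight split by columns, plus a replicated 1x2 weight.
shards = [
    {"dense.weight": [[1, 2], [5, 6]], "ln.weight": [[1.0, 1.0]]},
    {"dense.weight": [[3, 4], [7, 8]], "ln.weight": [[1.0, 1.0]]},
]
merged = gather_shards(shards, {"dense.weight": 1, "ln.weight": None})
print(merged["dense.weight"])  # [[1, 2, 3, 4], [5, 6, 7, 8]]
```

The real library works on torch tensors and knows how each layer was split, which is why the steps above go through a tensor parallel model instead of merging the dict by hand.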

Den4ikAI commented 1 year ago

Thanks