bottiger1 opened 2 years ago
Hello!
To load the custom pretrained model, you need to save both the config and the state dict in the checkpoint, e.g.
torch.save({"config": custom_config, "state_dict": custom_model_state_dict}, "model.pt")
Make sure you save only the state_dict and not the whole PL (PyTorch Lightning) checkpoint.
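A Lightning .ckpt file bundles the weights together with training state such as optimizer states and epoch/step counters. Below is a minimal sketch of turning an already-loaded Lightning checkpoint into the {"config", "state_dict"} layout described above. It assumes ToxicClassifier called save_hyperparameters(), so the training config sits under hyper_parameters["config"] (check the keys of your own checkpoint first), and that the weights carry a "model." prefix. The function name lightning_to_detoxify is hypothetical.

```python
from typing import Any, Dict, Mapping


def lightning_to_detoxify(pl_ckpt: Mapping[str, Any], prefix: str = "model.") -> Dict[str, Any]:
    """Build a {"config", "state_dict"} dict from a loaded Lightning checkpoint.

    Assumption: the LightningModule called save_hyperparameters(), so the
    training config is stored under pl_ckpt["hyper_parameters"]["config"].
    """
    config = pl_ckpt["hyper_parameters"]["config"]
    # Keep only the weights and strip the wrapper prefix from each key;
    # optimizer states, epoch/step counters, etc. are simply not copied.
    state_dict = {
        key[len(prefix):]: value
        for key, value in pl_ckpt["state_dict"].items()
        if key.startswith(prefix)
    }
    return {"config": config, "state_dict": state_dict}
```

Usage would be something like `pl_ckpt = torch.load("epoch=3-step=76.ckpt", map_location="cpu")` followed by `torch.save(lightning_to_detoxify(pl_ckpt), "model.pt")`. On recent PyTorch versions, torch.load may need `weights_only=False` to read a full Lightning checkpoint.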
I just went down this path of pain and suffering, so I want to post this for anyone else who wants to train this model on their own data. It was a really painful experience, but it taught me how Detoxify works at the tensor level lol.
trainer.fit(model, data_loader, valid_data_loader)
torch.save({"config": model.config, "state_dict": model.state_dict()},"model.pt")
My original idea after seeing this post was just to put this save line after trainer.fit. The issue I kept running into is that you can't use model.state_dict() directly, because every key in the state dictionary is prefixed with "model.", e.g. model.bert.encoder.layer.8.output.LayerNorm.weight needs to be converted to bert.encoder.layer.8.output.LayerNorm.weight. After translating every key in the state dictionary, I could successfully load the checkpoint with Detoxify.
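To see this prefix mismatch directly, it can help to group state-dict keys by their leading components. This is a small hypothetical helper in plain Python; it does not use any Detoxify API.

```python
from collections import Counter
from typing import Iterable


def key_prefixes(keys: Iterable[str], depth: int = 2) -> Counter:
    """Count state-dict keys by their first `depth` dot-separated components."""
    return Counter(".".join(key.split(".")[:depth]) for key in keys)


keys = [
    "model.bert.embeddings.word_embeddings.weight",
    "model.bert.encoder.layer.8.output.LayerNorm.weight",
    "model.classifier.weight",
    "model.classifier.bias",
]
print(key_prefixes(keys))  # Counter({'model.bert': 2, 'model.classifier': 2})
```

Calling it on `model.state_dict().keys()` for the Lightning module should show `model.bert` and `model.classifier`, whereas a plain `BertForSequenceClassification` would show `bert` and `classifier`.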
You need to add this to the bottom of train.py:
trainer.fit(model, data_loader, valid_data_loader)
# Every key in the LightningModule's state dict is prefixed with "model.";
# strip that prefix so the keys match the plain Hugging Face model.
full_state = model.state_dict()
statedict = {}
for param_tensor in full_state:
    if param_tensor.startswith("model.bert."):
        newname = param_tensor.replace("model.", "", 1)
        statedict[newname] = full_state[param_tensor]
statedict["classifier.weight"] = full_state["model.classifier.weight"]
statedict["classifier.bias"] = full_state["model.classifier.bias"]
torch.save({"config": model.config, "state_dict": statedict}, "model.pt")
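Before calling torch.save, it may be worth checking that the renamed keys line up with what the Hugging Face model expects. The sketch below uses a hypothetical compare_keys helper. One plausible source for the expected keys is `model.model.state_dict().keys()`. This assumes, as the "model." prefix suggests, that the Lightning module stores the Hugging Face model in an attribute called `model`.

```python
from typing import Iterable, Set, Tuple


def compare_keys(converted: Iterable[str], expected: Iterable[str]) -> Tuple[Set[str], Set[str]]:
    """Return (missing, unexpected): keys the target model expects but the
    converted dict lacks, and keys the converted dict has but the model does not."""
    converted_set, expected_set = set(converted), set(expected)
    return expected_set - converted_set, converted_set - expected_set
```

Both sets should come back empty. PyTorch's `load_state_dict(..., strict=False)` also reports `missing_keys` and `unexpected_keys`, so you can run the same check against a real model instance. If `model.model` is indeed the Hugging Face model, saving `model.model.state_dict()` directly should avoid the manual renaming altogether.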
Then you can just load your model with Detoxify like this:
from detoxify import Detoxify
ai = Detoxify(checkpoint="model.pt")
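For other models such as RoBERTa or ALBERT, you need to replace bert in the if statement above (e.g. with model.roberta. or model.albert.). Note that RoBERTa's classification head is stored as classifier.dense.* and classifier.out_proj.* rather than classifier.weight and classifier.bias, so the two classifier lines also need adjusting for it.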
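The loop in train.py copies only `classifier.weight` and `classifier.bias`, which matches BERT's single linear head. Below is a hypothetical generalization that takes the backbone attribute name as a parameter and copies every `classifier.*` key. It should therefore also cover RoBERTa-style heads. It raises an error if nothing is found under the given backbone, which catches a forgotten bert to roberta rename. The function name and the "model." wrapper prefix default are assumptions based on this thread.

```python
from typing import Any, Dict, Mapping


def convert_state_dict(
    state_dict: Mapping[str, Any], backbone: str = "bert", wrapper_prefix: str = "model."
) -> Dict[str, Any]:
    """Keep backbone and classifier weights and drop the wrapper prefix.

    `backbone` is the encoder's attribute name inside the Hugging Face model,
    e.g. "bert", "roberta" or "albert".
    """
    keep = (f"{wrapper_prefix}{backbone}.", f"{wrapper_prefix}classifier.")
    converted = {
        key[len(wrapper_prefix):]: value
        for key, value in state_dict.items()
        if key.startswith(keep)  # str.startswith accepts a tuple of prefixes
    }
    # Fail loudly on a wrong backbone name instead of saving a checkpoint
    # that only contains the classifier head.
    if not any(key.startswith(f"{backbone}.") for key in converted):
        raise ValueError(f"no keys found under {wrapper_prefix}{backbone}.")
    return converted
```

In train.py, `statedict = convert_state_dict(model.state_dict(), backbone="roberta")` would then take the place of the loop and the two classifier lines.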
Hello, I want to train the network on my own samples, but I'm finding it quite difficult.
Right now, I have edited Toxic_comment_classification_BERT.json to point to my own training and test CSVs. I also have to edit train.py to manually save the model object inside ToxicClassifier at the end of training.
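Instead of editing the shipped JSON by hand, you could write a copy with the data paths swapped. This sketch assumes the config keeps its data paths under `config["dataset"]["args"]` as `train_csv_file` and `test_csv_file`; verify those key names against your own Toxic_comment_classification_BERT.json. The helper name is hypothetical.

```python
import copy
from typing import Any, Dict


def with_custom_data(config: Dict[str, Any], train_csv: str, test_csv: str) -> Dict[str, Any]:
    """Return a copy of a training config that points at custom CSV files.

    Assumption: data paths live in config["dataset"]["args"] under
    "train_csv_file" and "test_csv_file".
    """
    new_config = copy.deepcopy(config)  # leave the original config untouched
    new_config["dataset"]["args"]["train_csv_file"] = train_csv
    new_config["dataset"]["args"]["test_csv_file"] = test_csv
    return new_config
```

Load the original with `json.load`, pass it through this function, and write the result to a new file with `json.dump` for train.py to use.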
Then I have to load the file manually, instantiate a normal Detoxify instance, and replace its internal model object with the saved version to get it to work.
If I try to load the checkpoint generated at "saved\Jigsaw_BERT\lightning_logs\version_x\checkpoints\epoch=3-step=76.ckpt" with Detoxify, whether through the checkpoint parameter or by passing a file generated by torch.save(model), it always says:
What's the proper way of saving the checkpoint so that it includes the config and state dict? Or is my workaround the best way to use custom training data?