unitaryai / detoxify

Trained models & code to predict toxic comments on all 3 Jigsaw Toxic Comment Challenges. Built using ⚡ Pytorch Lightning and 🤗 Transformers. For access to our API, please email us at contact@unitary.ai.
https://www.unitary.ai/
Apache License 2.0
897 stars 115 forks

How do you load a custom checkpoint? #53

Open bottiger1 opened 2 years ago

bottiger1 commented 2 years ago

Hello, I want to train the network on my own samples, but I'm finding it quite difficult.

Right now I've edited Toxic_comment_classification_BERT.json to point to my own training and test CSVs. Then I have to edit train.py to manually save the model object inside ToxicClassifier at the end of training:

    # at the end of train.py: save the model object held inside ToxicClassifier
    torch.save(model.model, 'custom.pt')

Then I have to load the file manually, instantiate a normal Detoxify instance, and replace its internal model object with the saved version to get it to work:

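    import torch
    import detoxify

    saved = torch.load('custom.pt')  # on PyTorch >= 2.6, pass weights_only=False to unpickle a whole module
    d = detoxify.Detoxify('original')  # start from a stock instance
    d.model = saved  # swap in the custom-trained model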

If I try to load the checkpoint generated at "saved\Jigsaw_BERT\lightning_logs\version_x\checkpoints\epoch=3-step=76.ckpt" with Detoxify, i.e. by instantiating Detoxify with the checkpoint parameter, or try the same with a file generated by torch.save(model), it always says:

Checkpoint needs to contain the config it was trained with as well as the state dict
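For context, here is a minimal sketch of the kind of check that produces this message. It is a hypothetical re-creation, not Detoxify's actual source: it only assumes the loaded file must be a dict with top-level "config" and "state_dict" keys, as the message says. A PyTorch Lightning .ckpt stores the weights under "state_dict" but has no "config" key, and a whole pickled module is not a dict at all, so the sketch rejects both. The Lightning-style dict below is simplified and illustrative.

```python
from typing import Any

# The two top-level keys named in the error message.
REQUIRED_KEYS = ("config", "state_dict")


def check_checkpoint(loaded: Any) -> None:
    """Hypothetical re-creation of the check behind Detoxify's error message."""
    if not isinstance(loaded, dict) or any(k not in loaded for k in REQUIRED_KEYS):
        raise ValueError(
            "Checkpoint needs to contain the config it was trained with "
            "as well as the state dict"
        )


# Simplified top-level layout of a Lightning .ckpt (real values are tensors, optimizer state, etc.).
lightning_like = {"epoch": 3, "global_step": 76, "state_dict": {}, "optimizer_states": []}
detoxify_like = {"config": {}, "state_dict": {}}

check_checkpoint(detoxify_like)  # passes silently
try:
    check_checkpoint(lightning_like)
except ValueError as err:
    print("rejected:", err)
```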

What's the proper way of saving the checkpoint so it has the config and state dict with it? Or is my workaround the best way to use custom training data?

laurahanu commented 2 years ago

Hello! To load a custom pretrained model, you need to save both the config and the state dict in the checkpoint, e.g. torch.save({"config": custom_config, "state_dict": custom_model_state_dict}, "custom.pt"). Make sure you save only the model's state_dict and not the whole PyTorch Lightning (PL) checkpoint.
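For a run that has already finished, here is a minimal sketch of pulling only the weights out of an existing Lightning .ckpt and pairing them with the config. It assumes (not confirmed in this thread) that ToxicClassifier keeps the Transformers model as its model attribute, so the relevant keys start with "model.", and that the config to store is the same dict train.py loaded from the training JSON. The function name lightning_ckpt_to_detoxify is hypothetical, and plain strings stand in for tensors.

```python
from typing import Any, Dict, Mapping


def lightning_ckpt_to_detoxify(
    ckpt: Mapping[str, Any], config: Dict[str, Any], prefix: str = "model."
) -> Dict[str, Any]:
    """Keep only the model weights from a loaded Lightning checkpoint.

    Assumption: the LightningModule stores the Transformers model as
    `self.model`, so its weights sit in ckpt["state_dict"] under `prefix`.
    Everything else in the .ckpt (epoch, optimizer state, ...) is dropped.
    """
    state_dict = {
        key[len(prefix):]: value
        for key, value in ckpt["state_dict"].items()
        if key.startswith(prefix)
    }
    return {"config": config, "state_dict": state_dict}


# Stand-in for a loaded .ckpt; strings replace the real tensors.
fake_ckpt = {
    "epoch": 3,
    "optimizer_states": [],
    "state_dict": {
        "model.bert.pooler.dense.weight": "w0",
        "model.classifier.weight": "w1",
    },
}
converted = lightning_ckpt_to_detoxify(fake_ckpt, config={"arch": {"args": {}}})
print(sorted(converted["state_dict"]))
```

With real files, ckpt would presumably come from torch.load(path, map_location="cpu") on the epoch=3-step=76.ckpt file from the question, config from json.load on the training config (e.g. Toxic_comment_classification_BERT.json), and the returned dict would then be written out with torch.save(converted, "model.pt").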

auadams commented 1 year ago

I just went down this path of pain and suffering, and I want to post this for anyone else who wants to train this model on their own data. It was a really painful experience that taught me how detoxify works at the tensor level lol.

    trainer.fit(model, data_loader, valid_data_loader)
    torch.save({"config": model.config, "state_dict": model.state_dict()},"model.pt")

My original idea after seeing this post was just to put this save line after trainer.fit. The issue I kept running into is that you can't use model.state_dict() directly: ToxicClassifier holds the Transformers model in its model attribute, so every key in the state dictionary is prefixed with "model.", e.g. model.bert.encoder.layer.8.output.LayerNorm.weight needs to be converted to bert.encoder.layer.8.output.LayerNorm.weight. After renaming every key in the state dictionary, I could successfully load the file with Detoxify's checkpoint parameter.
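The renaming described here boils down to stripping a leading "model." from each key. Below is a minimal, architecture-agnostic sketch; strip_prefix is a hypothetical helper, not part of Detoxify. Because it keys on the prefix rather than on "bert", the same function would cover roberta.* or albert.* weights and whatever the classification head's parameters are called, and it returns the keys it skipped so nothing disappears silently. Plain strings stand in for tensors.

```python
from typing import Any, Dict, List, Mapping, Tuple


def strip_prefix(
    state_dict: Mapping[str, Any], prefix: str = "model."
) -> Tuple[Dict[str, Any], List[str]]:
    """Return (renamed, skipped): keys starting with `prefix` lose it; others are skipped.

    Unlike str.replace, this removes only the leading prefix, so an inner
    "model." elsewhere in a key is left alone.
    """
    renamed: Dict[str, Any] = {}
    skipped: List[str] = []
    for key, value in state_dict.items():
        if key.startswith(prefix):
            renamed[key[len(prefix):]] = value
        else:
            skipped.append(key)
    return renamed, skipped


# Key names follow the BERT example above.
wrapped = {
    "model.bert.encoder.layer.8.output.LayerNorm.weight": "t0",
    "model.classifier.weight": "t1",
    "model.classifier.bias": "t2",
}
renamed, skipped = strip_prefix(wrapped)
print(sorted(renamed))
print(skipped)
```

In train.py this would presumably be called as strip_prefix(model.state_dict()) after trainer.fit, with the first element of the result saved as the "state_dict" entry.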

You need to add this to the bottom of train.py:

    trainer.fit(model, data_loader, valid_data_loader)
    # every key in ToxicClassifier's state dict starts with "model."; Detoxify expects the bare keys
    full_state = model.state_dict()  # compute once instead of on every loop iteration
    statedict = {}
    for param_tensor, value in full_state.items():
        if param_tensor.startswith("model.bert."):
            statedict[param_tensor[len("model."):]] = value  # drop only the leading "model."
    statedict["classifier.weight"] = full_state["model.classifier.weight"]
    statedict["classifier.bias"] = full_state["model.classifier.bias"]
    torch.save({"config": model.config, "state_dict": statedict}, "model.pt")
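Before pointing Detoxify at model.pt, it may be worth checking that the renamed keys line up with what the bare Transformers model expects. Assuming ToxicClassifier keeps that model as model.model (the object the original question saves with torch.save(model.model, ...)), the expected keys would be model.model.state_dict().keys() and the actual ones statedict.keys(). compare_keys below is a hypothetical helper that reports both kinds of mismatch; the key lists are illustrative.

```python
from typing import Iterable, Set, Tuple


def compare_keys(
    expected: Iterable[str], actual: Iterable[str]
) -> Tuple[Set[str], Set[str]]:
    """Return (missing, unexpected): keys absent from `actual`, and extra keys in it."""
    expected_set, actual_set = set(expected), set(actual)
    return expected_set - actual_set, actual_set - expected_set


# Illustrative key lists: here the renamed dict lacks the classifier bias.
expected_keys = ["bert.pooler.dense.weight", "classifier.weight", "classifier.bias"]
renamed_keys = ["bert.pooler.dense.weight", "classifier.weight"]
missing, unexpected = compare_keys(expected_keys, renamed_keys)
print("missing:", sorted(missing), "unexpected:", sorted(unexpected))
```

Under the same assumption, saving model.model.state_dict() directly would give keys without the "model." prefix in the first place, which would make the renaming loop unnecessary. PyTorch's own load_state_dict(..., strict=False) also returns missing_keys and unexpected_keys if you prefer to check against a loaded module.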

Then you can load your model with Detoxify like this:

    from detoxify import Detoxify
    ai = Detoxify(checkpoint="model.pt")

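Also, for other models like RoBERTa or ALBERT, you need to replace bert in the if statement above with the matching prefix (roberta or albert). Note that RoBERTa's classification head uses different parameter names (classifier.dense.* and classifier.out_proj.* instead of classifier.weight and classifier.bias), so the two classifier lines need adjusting as well.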