Lightning-Universe / lightning-transformers

Flexible components pairing 🤗 Transformers with ⚡ PyTorch Lightning
https://lightning-transformers.readthedocs.io
Apache License 2.0

Support saving the model so that it can be reloaded with Hugging Face #188

Closed. dfioravanti closed this issue 2 years ago.

dfioravanti commented 3 years ago

🚀 Feature

Right now, at the end of training the model is saved as a .ckpt file, and Hugging Face really does not like .ckpt files. From what I can tell, there is no easy way to load such a checkpoint into an HF model.

Motivation

Having this feature would make deploying a model much easier, as Hugging Face Transformers is becoming a de facto standard library that quite a few people know and use.

Pitch

It should be possible to save the best model in a format that HF can load directly (e.g. with from_pretrained).

@SeanNaren wrote in Slack that this could potentially be done within the LightningModule (or the TransformerTask) in the on_save_checkpoint hook; however, we also have to make sure the model is loaded back correctly.

mathemusician commented 3 years ago

For now, I've resorted to monkey-patching as a temporary workaround. If you have a violent allergic reaction to code that has a tendency to break easily, please do not read the following:

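```python
import torch
import pytorch_lightning as pl


class ModelSwitcher:
    def __init__(self):
        self.model = pl.LightningModule()  # load model here
        self.checkpoint_path = "path/to/model.ckpt"

    def load_checkpoint(self):
        # Lightning checkpoints are pickles with non-tensor entries
        # (hyperparameters, loop state); only load files you trust
        checkpoint = torch.load(self.checkpoint_path, map_location="cpu", weights_only=False)
        old_state_dict = checkpoint["state_dict"]
        # the checkpoint contains a dictionary of weights that can be
        # directly mapped to the transformer weights, the names are just
        # a little different, you might need to tweak this
        for key, value in old_state_dict.items():
            try:
                # split "a.b.weight" into the module path ["a", "b"] and the
                # tensor name "weight", so biases are not written into weights
                *module_path, tensor_name = key.split(".")
                module = self.model
                for name in module_path:
                    module = getattr(module, name)  # also resolves list indices like "0"
                # copy in place into the existing parameter/buffer
                with torch.no_grad():
                    getattr(module, tensor_name).copy_(value)
            except Exception as e:
                print(e, f"could not find {key} in model")
```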

I can confirm this works for PyTorch Lightning models, and since it only manipulates tensors at the PyTorch level, it might work in other projects too. You have to inspect your parameter names to make SURE the naming conventions match up. Furthermore, there's no check that the tensor shapes in the checkpoint match the model's. Please don't take this post too seriously; it is only for the truly desperate.
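A somewhat less fragile variant of the same idea, sketched under two assumptions: the Hugging Face model was stored as self.model inside the LightningModule (so its keys in the checkpoint start with "model."), and you can rebuild an HF model with the same config and the same transformers version used for training. Instead of assigning tensors one at a time, it strips the prefix and calls load_state_dict with strict=True, so missing, unexpected, or wrongly shaped tensors raise an error instead of passing silently. The function name convert_lightning_checkpoint is hypothetical.

```python
import torch
from transformers import PreTrainedModel


def convert_lightning_checkpoint(
    checkpoint_path: str,
    hf_model: PreTrainedModel,
    output_dir: str,
    prefix: str = "model.",
) -> None:
    """Copy weights from a Lightning .ckpt into hf_model and save it in HF format.

    Assumes the LightningModule held the Hugging Face model as self.model,
    so every relevant checkpoint key starts with `prefix`.
    """
    # Lightning checkpoints may hold non-tensor objects (hyperparameters,
    # loop state), hence weights_only=False; only load files you trust.
    checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
    state_dict = {
        key[len(prefix):]: tensor
        for key, tensor in checkpoint["state_dict"].items()
        if key.startswith(prefix)
    }
    # strict=True raises on missing or unexpected keys, and load_state_dict
    # always raises on a shape mismatch, so naming and size problems surface
    # as a RuntimeError rather than a partially loaded model.
    hf_model.load_state_dict(state_dict, strict=True)
    hf_model.save_pretrained(output_dir)
```

After conversion, output_dir can be loaded with the usual from_pretrained call of the matching model class.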

stale[bot] commented 2 years ago

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.