allenai / allennlp

An open-source NLP research library, built on PyTorch.
http://www.allennlp.org
Apache License 2.0
11.76k stars, 2.25k forks

Saving and loading an in-memory `Model` & `config.json` is very hard #1289

Closed DeNeutoy closed 4 years ago

DeNeutoy commented 6 years ago

Here is a script I just wrote to load a model with the custom LSTM kernel and re-pack it into a new model so we can put it in the demo and distribute it. It turned out that we don't have any functionality for creating an archive from an in-memory `Model` and config, and doing so was pretty messy. Hacking on models like this is pretty key for, e.g., transfer learning and research in general: "what happens if I take part of this model and put it in this other one?"

from allennlp.models.archival import load_archive, archive_model
from allennlp.common.params import Params
from allennlp.modules.seq2seq_encoders.seq2seq_encoder import Seq2SeqEncoder

import torch
import os
import shutil
import json

root_dir = '/net/efs/aristo/allennlp/srl_lm/emnlp/srl/5b-elmo/'
fname = os.path.join(root_dir, 'model.tar.gz')

archive = load_archive(fname)

model = archive.model
params = archive.config
new_params = params.duplicate()

# Re-use the old encoder's hyperparameters, but switch the config to the new encoder type.
new_encoder_params = params.get("model").get("encoder").as_dict()
new_encoder_params["type"] = "alternating_lstm"
new_encoder_params["use_input_projection_bias"] = False
new_params["model"]["encoder"] = new_encoder_params

old_encoder = model.encoder._module
new_encoder = Seq2SeqEncoder.from_params(Params(new_encoder_params.copy()))

# Running offsets into the old encoder's flat weight and bias tensors.
weight_index = 0
bias_index = 0

num_layers = 8  # hard-coded for this particular model
for layer_index in range(num_layers):
    layer = getattr(new_encoder._module, 'layer_%d' % layer_index)
    input_weight = layer.input_linearity.weight
    state_weight = layer.state_linearity.weight
    bias = layer.state_linearity.bias

    # Each slice of the flat tensor is copied into the transposed view of the
    # corresponding new parameter, which writes through to the parameter itself.
    input_size = input_weight.nelement()
    old_input = old_encoder.weight.data[weight_index: weight_index + input_size]
    input_weight.data.t().copy_(old_input.view_as(input_weight.t()))
    weight_index += input_size

    state_size = state_weight.nelement()
    old_state = old_encoder.weight.data[weight_index: weight_index + state_size]
    state_weight.data.t().copy_(old_state.view_as(state_weight.t()))
    weight_index += state_size

    bias.data.copy_(old_encoder.bias.data[bias_index:bias_index + bias.nelement()])
    bias_index += bias.nelement()

model.encoder = new_encoder

serialisation_dir = "./new_model/"
os.makedirs(serialisation_dir)

shutil.copytree(os.path.join(root_dir, "vocabulary"), os.path.join(serialisation_dir, "vocabulary"))

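# Carry over the list of extra files recorded next to the original model.
with open(os.path.join(root_dir, "files_to_archive.json"), "r") as f:
    fta = json.load(f)

print(fta)
# Write the modified config where archive_model expects to find it.
with open(os.path.join(serialisation_dir, "config.json"), "w") as f:
    json.dump(new_params.as_dict(quiet=True), f)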

model_state = model.state_dict()
torch.save(model_state, os.path.join(serialisation_dir, "best.th"))

archive_model(serialisation_dir, files_to_archive=fta)

schmmd commented 6 years ago

We would write an archive_model_from_memory function.
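
For illustration only, here is a rough sketch of what such a helper might look like if it worked purely from in-memory data. It is not AllenNLP's API: the function name `archive_from_memory`, its parameters, and the member names `config.json`, `weights.th` and `vocabulary` are all assumptions made for this example. It uses only the standard library's `tarfile` module and adds the config and weights to the tarball straight from memory.

```python
import io
import json
import os
import tarfile
import tempfile
from typing import Optional


def archive_from_memory(config: dict, weights: bytes, archive_path: str,
                        vocabulary_dir: Optional[str] = None) -> str:
    """Pack an in-memory config and serialized weights into a .tar.gz file.

    Hypothetical helper: the member names used here are assumptions
    chosen for illustration, not a documented archive layout.
    """
    payloads = (
        ("config.json", json.dumps(config, indent=2).encode("utf-8")),
        ("weights.th", weights),
    )
    with tarfile.open(archive_path, "w:gz") as archive:
        # Add the in-memory payloads directly, without temporary files.
        for name, payload in payloads:
            info = tarfile.TarInfo(name=name)
            info.size = len(payload)
            archive.addfile(info, io.BytesIO(payload))
        # A vocabulary that already lives on disk is added as a directory.
        if vocabulary_dir is not None:
            archive.add(vocabulary_dir, arcname="vocabulary")
    return archive_path


# Small demonstration with placeholder weights. With PyTorch, the bytes could
# come from torch.save(model.state_dict(), buffer) on an io.BytesIO buffer.
with tempfile.TemporaryDirectory() as tmp:
    path = archive_from_memory({"model": {"type": "example"}}, b"\x00\x01",
                               os.path.join(tmp, "model.tar.gz"))
    with tarfile.open(path, "r:gz") as archive:
        member_names = sorted(archive.getnames())
print(member_names)
```

Taking the weights as raw bytes keeps the sketch independent of any particular framework; the caller decides how the model state is serialized.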

zzyxzz commented 6 years ago

Any updates on this thread? I'd like to save a trained model. Judging from the examples for the `train` command, if we don't use the `allennlp train` command, we have to save the vocabulary, model parameters, and configuration separately. Am I right?

zhxt95 commented 5 years ago

This function would be really useful. I hope it becomes available soon.

phimit commented 5 years ago

The tricky part right now is providing a config during save for a custom model. I can't even find an example.
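
One way to picture the problem is a model class that can describe itself as a config. The sketch below is a toy illustration and not AllenNLP's mechanism: `MODEL_REGISTRY`, `register`, `ToyTagger`, `to_config` and `from_config` are all hypothetical names. The only thing borrowed from the script at the top of this thread is the idea of a `"type"` key that names which class to build.

```python
import json

MODEL_REGISTRY = {}  # hypothetical mapping from type name to class


def register(name):
    """Class decorator that records a class under `name`."""
    def decorator(cls):
        cls.type_name = name
        MODEL_REGISTRY[name] = cls
        return cls
    return decorator


@register("toy_tagger")
class ToyTagger:
    """Hypothetical custom model that knows its own constructor arguments."""

    def __init__(self, hidden_size: int, dropout: float = 0.0):
        self.hidden_size = hidden_size
        self.dropout = dropout

    def to_config(self) -> dict:
        # Emit the constructor arguments plus a "type" key naming the class.
        return {"type": self.type_name,
                "hidden_size": self.hidden_size,
                "dropout": self.dropout}


def from_config(config: dict):
    """Rebuild an object from a dict produced by `to_config`."""
    kwargs = dict(config)  # copy so the caller's dict is not mutated
    cls = MODEL_REGISTRY[kwargs.pop("type")]
    return cls(**kwargs)


# Round trip: object -> JSON text -> object.
model = ToyTagger(hidden_size=128, dropout=0.1)
config_text = json.dumps({"model": model.to_config()})
restored = from_config(json.loads(config_text)["model"])
```

The point of the sketch is that the config has to be kept in sync with the constructor by hand, which is exactly the part that gets awkward for custom models.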

matt-gardner commented 5 years ago

Do we have anything better to say here? I know we've talked about it a couple of times since this issue was opened, but I don't remember the current state of things.

schmmd commented 4 years ago

This is a necessary precondition for better supporting the use of AllenNLP via Python code rather than configuration. Can we leverage PyTorch's save functionality? Can we just use Python's dill?
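
To make the trade-off between these options concrete, here is a small sketch contrasting "serialize the whole object" with "store config and state separately, then rebuild". It uses the standard library's `pickle` as a stand-in for dill (a third-party package that extends pickle), and `ToyModel` is a hypothetical plain-Python class rather than a PyTorch module, so it only shows the shape of the two approaches.

```python
import json
import pickle


class ToyModel:
    """Stand-in for a model: constructor arguments plus learned state."""

    def __init__(self, hidden_size: int):
        self.hidden_size = hidden_size
        self.weights = [0.0] * hidden_size

    def state_dict(self) -> dict:
        return {"weights": list(self.weights)}

    def load_state_dict(self, state: dict) -> None:
        self.weights = list(state["weights"])


model = ToyModel(hidden_size=3)
model.weights = [0.1, 0.2, 0.3]

# Option 1: serialize the whole object. Simple, but loading requires the
# same class definition to be importable under the same name.
whole_blob = pickle.dumps(model)
from_pickle = pickle.loads(whole_blob)

# Option 2: store a config and the state separately, then rebuild the
# object from the config and load the state into it.
config_text = json.dumps({"hidden_size": model.hidden_size})
state_blob = pickle.dumps(model.state_dict())
rebuilt = ToyModel(**json.loads(config_text))
rebuilt.load_state_dict(pickle.loads(state_blob))
```

Option 1 needs no config at all, which is attractive when models are built in code; option 2 keeps a human-readable description of how the model was constructed.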

matt-gardner commented 4 years ago

We have a section on this in the upcoming course, showing that it's now pretty easy, with or without config files. It's got example code. I'm closing this issue as finished.