apache / mxnet

Lightweight, Portable, Flexible Distributed/Mobile Deep Learning with Dynamic, Mutation-aware Dataflow Dep Scheduler; for Python, R, Julia, Scala, Go, Javascript and more
https://mxnet.apache.org
Apache License 2.0

Support model with shared parameters save_parameters(..., deduplicate=True) to load with model variations that don't share the parameters #18717

Open zheyuye opened 4 years ago

zheyuye commented 4 years ago

Description

import mxnet as mx
from mxnet.gluon import HybridBlock, nn
import tempfile
import os
mx.npx.set_np()

class Foo(HybridBlock):
    def __init__(self, use_mlm=False):
        super().__init__()
        self.use_mlm = use_mlm
        self.l1 = nn.Dense(16)
        if self.use_mlm:
            self.l2 = nn.Dense(16)
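            # tie l2's weight and bias to l1's parameters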
            self.l2.share_parameters(self.l1.collect_params())

    def hybrid_forward(self, F, x):
        x = self.l1(x)
        if self.use_mlm:
            x = self.l2(x)

        return x

foo = Foo(use_mlm=True)
foo.initialize()
foo(mx.np.ones((32, 16)))
foo2 = Foo(use_mlm=False)
with tempfile.TemporaryDirectory() as dir_path:
    foo.save_parameters(os.path.join(dir_path, 'test.params'), deduplicate=True)
    parameters = mx.npx.load(os.path.join(dir_path, 'test.params'))
    print(parameters.keys())
    foo2.load_parameters(os.path.join(dir_path, 'test.params'))

Output:

>>>dict_keys(['l2.weight', 'l2.bias'])
>>>AssertionError: Parameter 'l1.weight' is missing in 'file: /tmp/tmp3a6xslz2/test.params', which contains parameters: 'l2.weight', 'l2.bias'. Set allow_missing=True to ignore missing parameters.

Here l1 and l2 share their parameters, and thanks to the deduplicate flag the shared parameters are saved only once, using the last parameter name as the key, e.g. dict_keys(['l2.weight', 'l2.bias']). There is nothing wrong with that until we try to load only a subset of the parameters, as with foo2 = Foo(use_mlm=False).
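One possible workaround (just a sketch; name_map below is a hand-written assumption about which saved name corresponds to which parameter in the smaller model, not an MXNet API) is to load the raw dict, remap the deduplicated keys, and assign the arrays to foo2's parameters manually:

# Sketch only: remap the deduplicated keys ('l2.*') to the names foo2 expects ('l1.*')
# and set the data directly. name_map is an assumption, not something MXNet provides.
name_map = {'l2.weight': 'l1.weight', 'l2.bias': 'l1.bias'}

foo2 = Foo(use_mlm=False)
foo2.initialize()
foo2(mx.np.ones((32, 16)))                  # run once so deferred shapes are inferred

loaded = mx.npx.load(param_file)            # param_file: path of the file saved above
target_params = foo2.collect_params()       # keys like 'l1.weight', 'l1.bias'
for saved_name, arr in loaded.items():
    target_params[name_map.get(saved_name, saved_name)].set_data(arr)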

Of course, we could avoid the problem by calling l1 repeatedly instead of creating a separate layer l2 that shares weights with l1 (a sketch of that alternative follows). However, the scenario in the next example is fairly common in pretrained models that use masked language modelling as the pretraining objective:
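For illustration, the "call l1 twice" alternative might look like this (a sketch, not part of the original report):

class FooNoSharing(HybridBlock):
    """Hypothetical variant: reuse l1 directly instead of sharing its weights with a second layer."""
    def __init__(self, use_mlm=False):
        super().__init__()
        self.use_mlm = use_mlm
        self.l1 = nn.Dense(16)

    def hybrid_forward(self, F, x):
        x = self.l1(x)
        if self.use_mlm:
            x = self.l1(x)   # apply the same layer again; nothing to deduplicate on save
        return x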

import mxnet as mx
from mxnet.gluon import HybridBlock, nn
import tempfile
import os
mx.npx.set_np()

class Foo(HybridBlock):
    def __init__(self, use_mlm=False):
        super().__init__()
        self.use_mlm = use_mlm
        self.vocab_size = 30522
        self.word_embed = nn.Embedding(input_dim=self.vocab_size,
                                       output_dim=64)

        if self.use_mlm:
            self.mlm_decoder = nn.HybridSequential()
            self.mlm_decoder.add(nn.Dense(units=64, flatten=False))
            self.mlm_decoder.add(nn.Dense(units=self.vocab_size, flatten=False))
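            # tie the last projection's weight to the word embedding weight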
            self.mlm_decoder[-1].share_parameters(self.word_embed.collect_params())

    def hybrid_forward(self, F, x):
        x = self.word_embed(x)
        if self.use_mlm:
            x = self.mlm_decoder(x)
        return x

foo = Foo(use_mlm=True)
foo.initialize()
foo(mx.np.ones((8,)))
foo2 = Foo(use_mlm=False)
with tempfile.TemporaryDirectory() as dir_path:
    foo.save_parameters(os.path.join(dir_path, 'test.params'), deduplicate=True)
    parameters = mx.npx.load(os.path.join(dir_path, 'test.params'))
    print(parameters.keys())
    foo2.load_parameters(os.path.join(dir_path, 'test.params'))
Output:

>>>dict_keys(['mlm_decoder.1.weight', 'mlm_decoder.0.weight', 'mlm_decoder.0.bias', 'mlm_decoder.1.bias'])

Here mlm_decoder is only used during pretraining and would be discarded when fine-tuning on downstream tasks. In the mlm_decoder, we usually need to predict the masked tokens by mapping back to vocabulary indices through a dense layer whose parameters are shared with word_embed. However, saving this way produces a parameter file without word_embed.weight.

leezu commented 4 years ago

In summary, the feature request is that save_parameters(..., deduplicate=True) store all the names a shared parameter is known under, so that the resulting parameter file can be loaded into arbitrary variations of the original model in which a different set of parameters is shared.
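Purely as an illustration of that idea (not MXNet's actual format or API), a deduplicating save could record every alias of a shared array, so that any model variant can look its parameters up under whichever name it uses:

# Toy sketch: group arrays by identity, remember all names, and let a variant
# pick whichever alias it knows. Plain Python structures stand in for the real file format.
def dedup_with_aliases(flat_params):
    """flat_params: dict name -> array, where shared parameters are the same object."""
    groups = {}
    for name, arr in flat_params.items():
        groups.setdefault(id(arr), (arr, []))[1].append(name)
    return [(names, arr) for arr, names in groups.values()]

def select_for_variant(saved, wanted_names):
    """Return {wanted_name: array} by matching any stored alias."""
    out = {}
    for names, arr in saved:
        for alias in names:
            if alias in wanted_names:
                out[alias] = arr
    return out

For example, dedup_with_aliases({'l1.weight': w, 'l2.weight': w}) keeps one copy of w under both names, and select_for_variant(..., {'l1.weight'}) recovers it for the model that does not share parameters.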

It's not really a bug, because the same limitation is present in MXNet 1.x's save_parameters(..., deduplicate=True). It's just that, due to an internal implementation change, 1.x stored the first name under which the parameter was known, whereas the current code stores the last one.
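As a toy illustration of that difference (assuming the flat name-to-array dict that the dedup logic walks over), a parameter shared under two names reduces to a single entry, and only the choice of which name to keep differs:

# Identity-based dedup of a parameter known under two names.
shared = object()                                   # stands in for the shared array
flat = {'l1.weight': shared, 'l2.weight': shared}

groups = {}
for name, arr in flat.items():                      # traversal in insertion order
    groups.setdefault(id(arr), []).append(name)

print([names[0] for names in groups.values()])      # ['l1.weight']  -> 1.x kept the first name
print([names[-1] for names in groups.values()])     # ['l2.weight'] -> current code keeps the last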