lucidrains / routing-transformer

Fully featured implementation of Routing Transformer
MIT License

Building and training a RoutingTransformerEncDec from pre-trained RoutingTransformerLMs #21

Closed AliOskooeiTR closed 3 years ago

AliOskooeiTR commented 3 years ago

I am trying to build and train an encoder-decoder from pretrained Routing Transformer LMs. My approach was to replace the encoder and decoder in a RoutingTransformerEncDec with the pretrained RoutingTransformerLMs, as follows:

enc_dec.enc = pretrained_lm
enc_dec.dec = AutoregressiveWrapper(pretrained_lm)

and then train the enc_dec as usual, at which point I get the following error:

RuntimeError                              Traceback (most recent call last)
<ipython-input-9-681d3315d6dc> in <module>
    147         grad_accum_steps=1,
    148         temperature=1,
--> 149         model_suffix=''
    150 
    151     )

~/projects/trlabs_routing_transformer/routing_sum/train_and_eval.py in train_routing_single(epoch, model, tokenizer, train_chunk_bucket, val_data_bucket, model_dir, optimizer, lr, max_seq_len, pred_target_len, src_pad_len, tgt_pad_len, max_src_len, max_tgt_len, log_interval, eval_interval, save_interval, train_logger, global_step, grad_accum_steps, temperature, model_suffix)
    469         train_seq_out = padded_target[:, :max_seq_len].to(device)
    470         loss, aux_loss = model(train_seq_in, train_seq_out, return_loss=True)
--> 471         loss.backward()
    472         aux_loss.backward()
    473         train_loss += loss.item()

~/anaconda3/envs/routing/lib/python3.6/site-packages/torch/tensor.py in backward(self, gradient, retain_graph, create_graph)
    196                 products. Defaults to ``False``.
    197         """
--> 198         torch.autograd.backward(self, gradient, retain_graph, create_graph)
    199 
    200     def register_hook(self, hook):

~/anaconda3/envs/routing/lib/python3.6/site-packages/torch/autograd/__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables)
     98     Variable._execution_engine.run_backward(
     99         tensors, grad_tensors, retain_graph, create_graph,
--> 100         allow_unreachable=True)  # allow_unreachable flag
    101 
    102 

RuntimeError: new kmeans has not been supplied

I would appreciate any feedback on what the problem may be, or on the best way to build an enc_dec from pretrained LM checkpoints.

lucidrains commented 3 years ago

@AliOskooeiTR Hi Ali, this is tricky because in the encoder/decoder setting the hook for updating the kmeans must be moved from the individual transformers to the EncDec class, as shown at https://github.com/lucidrains/routing-transformer/blob/master/routing_transformer/encoder_decoder.py#L78. Let me work on this later this weekend and see if there is a clean way to do it.
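
A rough sketch of the pattern being described here, with hypothetical helper names rather than the library's actual internals: each routing-attention layer holds a kmeans module whose centroids get refreshed when backward() runs, via a hook on the loss tensor, and in the encoder/decoder setting that hook needs to be owned by the EncDec so the update fires once per backward pass rather than once per LM.

def register_kmeans_update_on(loss, kmeans_modules):
    # when gradients flow through `loss`, refresh every kmeans module exactly once
    def update_on_backward(_grad):
        for km in kmeans_modules:
            km.update()  # hypothetical method: recompute the routing centroids
    loss.register_hook(update_on_backward)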

lucidrains commented 3 years ago

@AliOskooeiTR also, as shown at https://github.com/lucidrains/routing-transformer/blob/master/routing_transformer/encoder_decoder.py#L59

lucidrains commented 3 years ago

@AliOskooeiTR if you want to try version 1.4.3, this should fix it

enc.cancel_kmeans_update()  # cancel kmeans update for enc and dec
dec.cancel_kmeans_update()

encdec.enc = enc
encdec.dec = dec

encdec.register_kmeans_update() # reregister kmeans

AliOskooeiTR commented 3 years ago

@lucidrains I tried this:

pretrained_lm.cancel_kmeans_update()  # cancel kmeans update for enc and dec
enc_dec.enc = pretrained_lm
enc_dec.dec = AutoregressiveWrapper(pretrained_lm)

enc_dec.register_kmeans_update() # reregister kmeans

And unfortunately I still get the same error:

RuntimeError: new kmeans has not been supplied

lucidrains commented 3 years ago

@AliOskooeiTR Ohh, I see the problem: I don't think you can use the same pretrained LM for both the encoder and the decoder. Instead of doing it that way, why not do something like:

enc_dec.enc.load_state_dict(pretrained_lm.state_dict())
enc_dec.dec.load_state_dict(pretrained_lm.state_dict())
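
For completeness, a hedged sketch of what this could look like end to end. The constructor argument names follow the repo README and the values are placeholders that must match the hyperparameters the LM was pretrained with; strict=False is used because, as it turns out later in this thread, the decoder LM is causal and receives context, so its state dict will not line up exactly with a plain pretrained LM. The decoder LM is assumed to sit under the AutoregressiveWrapper as .net.

from routing_transformer import RoutingTransformerEncDec

# hyperparameters below are placeholders; they must match the pretrained LM
enc_dec = RoutingTransformerEncDec(
    dim = 512,
    enc_num_tokens = 20000,
    enc_depth = 6,
    enc_heads = 8,
    enc_max_seq_len = 4096,
    enc_window_size = 128,
    dec_num_tokens = 20000,
    dec_depth = 6,
    dec_heads = 8,
    dec_max_seq_len = 4096,
    dec_window_size = 128
)

# pretrained_lm is the RoutingTransformerLM checkpoint from above
enc_dec.enc.load_state_dict(pretrained_lm.state_dict(), strict = False)
enc_dec.dec.net.load_state_dict(pretrained_lm.state_dict(), strict = False)
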
AliOskooeiTR commented 3 years ago

@lucidrains Thanks for the suggestion. I tried this:

enc_dec.enc.load_state_dict(pretrained_lm.state_dict())
enc_dec.dec.load_state_dict(pretrained_lm.state_dict())

It gave me a missing key error, as the state dict keys for enc_dec.enc and pretrained_lm were different. Surprisingly, even the number of keys was quite different. I did not understand how this could be, since enc_dec.enc and pretrained_lm are both RoutingTransformerLMs.
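
A quick way to see exactly how two state dicts differ, using only plain PyTorch and the objects from above, is to diff their key sets before calling load_state_dict:

enc_keys = set(enc_dec.enc.state_dict().keys())
lm_keys = set(pretrained_lm.state_dict().keys())

print('only in enc_dec.enc:', sorted(enc_keys - lm_keys)[:10])
print('only in pretrained_lm:', sorted(lm_keys - enc_keys)[:10])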

I then replaced the enc_dec state dicts directly, just to see what would happen:

pretrained_lm.cancel_kmeans_update()
pretrained_dec = AutoregressiveWrapper(pretrained_lm)
enc_dec.enc.state_dict = pretrained_lm.state_dict
enc_dec.dec.state_dict = pretrained_dec.state_dict
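# (note: assigning to .state_dict only rebinds the method on the module; it does not copy any weights)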
enc_dec.register_kmeans_update() 

Trying to train this gave me the same kmeans error:

RuntimeError: new kmeans has not been supplied

AliOskooeiTR commented 3 years ago

Just closing this issue, as I have figured out why it wasn't possible to use the same LM for both the encoder and the decoder. The decoder LM must receive context and be causal. As a result, the decoder has a different architecture and state dict from the encoder LM, and the two are not interchangeable.
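
A hedged illustration of that conclusion; the constructor argument names are taken from the repo README, the values are placeholders, and the exact key difference depends on the library version:

from routing_transformer import RoutingTransformerLM

enc_lm = RoutingTransformerLM(num_tokens = 20000, dim = 512, depth = 6, heads = 8,
                              max_seq_len = 4096, window_size = 128, causal = False)
dec_lm = RoutingTransformerLM(num_tokens = 20000, dim = 512, depth = 6, heads = 8,
                              max_seq_len = 4096, window_size = 128, causal = True,
                              receives_context = True)

# the decoder carries extra context (cross) attention parameters, so the two
# state dicts have different keys and the weights are not interchangeable
print(len(enc_lm.state_dict()), len(dec_lm.state_dict()))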