lucidrains / reformer-pytorch

Reformer, the efficient Transformer, in Pytorch
MIT License

Predicting with Encoder-Decoder structure #64

Closed. nakarinh14 closed this issue 4 years ago.

nakarinh14 commented 4 years ago

This is a follow-up to my comment in https://github.com/lucidrains/reformer-pytorch/issues/50. How do you make a prediction for a test example with the encoder-decoder, after training with the code block mentioned in that issue?

lucidrains commented 4 years ago

@nakarinh14 you have to sample from the decoder, using the encoded keys from the encoder as context. you can also use the handy TrainingWrapper, which includes a generate function that supports nucleus (top-p) and top-k sampling

import torch
from reformer_pytorch import ReformerLM
from reformer_pytorch.generative_tools import TrainingWrapper, top_p

DE_SEQ_LEN = 1024
EN_SEQ_LEN = 1024

encoder = ReformerLM(
    num_tokens = 20000,
    dim = 1024,
    depth = 2,
    heads = 8,
    attn_chunks = 5,
    max_seq_len = DE_SEQ_LEN,
    fixed_position_emb = True,
    return_embeddings = True   # return embeddings instead of logits, so they can be fed to the decoder as keys
)

decoder = ReformerLM(
    num_tokens = 20000,
    dim = 1024,
    depth = 2,
    heads = 8,
    attn_chunks = 2,
    max_seq_len = EN_SEQ_LEN,
    fixed_position_emb = True,
    causal = True              # autoregressive decoder
)

x = torch.randint(0, 20000, (2, DE_SEQ_LEN)).long()   # toy source batch of token ids
enc_keys = encoder(x)                                  # encoder embeddings, used as keys by the decoder

decoder = TrainingWrapper(decoder)   # wraps the decoder to provide the generate helper (and a loss-computing forward)

yi = torch.tensor([[0], [0]]).long() # assume you are sampling batch size of 2, start tokens are id of 0
sample = decoder.generate(yi, 1024, filter_logits_fn=top_p, filter_thres=0.95, keys=enc_keys) # (2, <= 1024)
# decode the sample token ids

edit: just found a bug, so if you update the library, the TrainingWrapper should now work with your decoder
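
For the training side that the question refers to (the loop from #50), the wrapped decoder can also compute the loss directly. A minimal sketch continuing the snippet above, assuming TrainingWrapper's forward accepts return_loss and passes extra kwargs such as keys through to the decoder:

# target batch of token ids, same batch size as x above
y = torch.randint(0, 20000, (2, EN_SEQ_LEN)).long()

# the wrapper shifts the targets internally and returns a cross-entropy loss;
# keys is forwarded to the decoder and attended to as encoder context
loss = decoder(y, return_loss = True, keys = enc_keys)
loss.backward()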

lucidrains commented 4 years ago

@nakarinh14 I should probably just write out the full Reformer encoder / decoder into one class lol

nakarinh14 commented 4 years ago

@lucidrains Thank you very much for the replies! I agree it would be quite nice if there were a separate class for the encoder-decoder structure :)

lucidrains commented 4 years ago

@nakarinh14 https://github.com/lucidrains/reformer-pytorch#reformer-encoder-decoder-architecture done!
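
For reference, usage of the new class looks roughly like the sketch below, following the linked README section. The keyword names (enc_num_tokens, dec_max_seq_len, enc_input_mask) and the generate signature are taken from the README at the time and may differ across versions.

import torch
from reformer_pytorch import ReformerEncDec

DE_SEQ_LEN = 1024
EN_SEQ_LEN = 1024

enc_dec = ReformerEncDec(
    dim = 512,
    enc_num_tokens = 20000,        # encoder vocabulary size
    enc_depth = 6,
    enc_max_seq_len = DE_SEQ_LEN,
    dec_num_tokens = 20000,        # decoder vocabulary size
    dec_depth = 6,
    dec_max_seq_len = EN_SEQ_LEN
)

train_seq_in = torch.randint(0, 20000, (1, DE_SEQ_LEN)).long()
train_seq_out = torch.randint(0, 20000, (1, EN_SEQ_LEN)).long()
input_mask = torch.ones(1, DE_SEQ_LEN).bool()

# training step: the wrapper encodes the source and wires the keys into the decoder internally
loss = enc_dec(train_seq_in, train_seq_out, return_loss = True, enc_input_mask = input_mask)
loss.backward()

# prediction: pass the source sequence and start tokens, sample the target autoregressively
eval_seq_in = torch.randint(0, 20000, (1, DE_SEQ_LEN)).long()
eval_seq_out_start = torch.tensor([[0]]).long()  # assume 0 is the start token id
samples = enc_dec.generate(eval_seq_in, eval_seq_out_start, seq_len = EN_SEQ_LEN, eos_token = 1)  # assume 1 is the stop token id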