Closed nakarinh14 closed 4 years ago
@nakarinh14 you have to sample from the decoder, using the encoded keys from the encoder as context. You can also use the handy `TrainingWrapper`, which includes a `generate` function that supports nucleus and top-k sampling:
```python
import torch
from reformer_pytorch import ReformerLM
from reformer_pytorch.generative_tools import TrainingWrapper, top_p

DE_SEQ_LEN = 1024
EN_SEQ_LEN = 1024

encoder = ReformerLM(
    num_tokens = 20000,
    dim = 1024,
    depth = 2,
    heads = 8,
    attn_chunks = 5,
    max_seq_len = DE_SEQ_LEN,
    fixed_position_emb = True,
    return_embeddings = True    # return embeddings to be used as keys by the decoder
)

decoder = ReformerLM(
    num_tokens = 20000,
    dim = 1024,
    depth = 2,
    heads = 8,
    attn_chunks = 2,
    max_seq_len = EN_SEQ_LEN,
    fixed_position_emb = True,
    causal = True               # decoder must be autoregressive
)

x = torch.randint(0, 20000, (2, DE_SEQ_LEN)).long()
enc_keys = encoder(x)

decoder = TrainingWrapper(decoder)

yi = torch.tensor([[0], [0]]).long()  # assume you are sampling batch size of 2, start tokens are id of 0
sample = decoder.generate(yi, EN_SEQ_LEN, filter_logits_fn = top_p, filter_thres = 0.95, keys = enc_keys)  # (2, <= 1024)
# decode the sampled token ids back to text with your tokenizer
```
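For reference, the nucleus (top-p) filtering that `filter_logits_fn = top_p` requests does something along these lines. This is an illustrative sketch of the standard technique, not the library's exact code; the function name `top_p_filter` and the threshold semantics (keep the smallest set of tokens whose cumulative probability exceeds `thres`) are my own here:

```python
import torch
import torch.nn.functional as F

def top_p_filter(logits, thres = 0.95):
    """Nucleus filtering sketch: keep the smallest set of tokens whose
    cumulative probability exceeds `thres`; mask the rest to -inf."""
    sorted_logits, sorted_indices = torch.sort(logits, descending = True)
    cum_probs = torch.cumsum(F.softmax(sorted_logits, dim = -1), dim = -1)

    # mark tokens past the threshold, shifted right so the boundary token is kept
    remove = cum_probs > thres
    remove[..., 1:] = remove[..., :-1].clone()
    remove[..., 0] = False

    sorted_logits[remove] = float('-inf')
    # scatter the filtered logits back to their original vocabulary positions
    return sorted_logits.scatter(-1, sorted_indices, sorted_logits)

# one token dominates the distribution, so only it survives the filter
out = top_p_filter(torch.tensor([[0., 0., 0., 10.]]), thres = 0.95)
```

Sampling then proceeds from a softmax over the filtered logits, so low-probability tail tokens can never be drawn.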
edit: just found and fixed a bug, so if you update the package, the training wrapper should work with your decoder
@nakarinh14 I should probably just write out the full Reformer encoder / decoder into one class lol
@lucidrains Thank you very much for the replies! I agree it would be quite nice if there were a separate class for the encoder-decoder structure :)
This is a follow-up to my comment in https://github.com/lucidrains/reformer-pytorch/issues/50. After training with the code block mentioned in that issue, how do you make a prediction on a test example with the encoder-decoder setup?
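Conceptually, the prediction step answered above boils down to an autoregressive decode loop: feed the start tokens to the decoder (with the encoder keys as context), pick the next token, append it, and repeat. A minimal sketch of that loop, using greedy selection instead of nucleus sampling for simplicity; `greedy_generate` and the dummy decoder are hypothetical stand-ins, not part of the library:

```python
import torch
import torch.nn.functional as F

def greedy_generate(decoder_logits_fn, start_tokens, seq_len, keys = None):
    """Minimal autoregressive decode loop (sketch). `decoder_logits_fn(tokens, keys)`
    is assumed to return logits of shape (batch, tokens_len, vocab)."""
    out = start_tokens
    for _ in range(seq_len):
        logits = decoder_logits_fn(out, keys)        # (b, t, vocab)
        next_token = logits[:, -1].argmax(dim = -1, keepdim = True)
        out = torch.cat((out, next_token), dim = 1)  # append and continue
    return out[:, start_tokens.shape[1]:]            # strip the start tokens

# toy decoder standing in for the wrapped ReformerLM: always predicts token id 7
dummy = lambda tokens, keys: F.one_hot(torch.full(tokens.shape, 7), num_classes = 20).float()

sample = greedy_generate(dummy, torch.zeros(2, 1).long(), 5)  # shape (2, 5)
```

`TrainingWrapper.generate` wraps this kind of loop and additionally applies temperature and logit filtering (top-k / nucleus) before sampling.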