lucidrains / routing-transformer

Fully featured implementation of Routing Transformer
MIT License

top_p returns wrong values and re-orders the data #1

Closed: tomweingarten closed this issue 4 years ago

tomweingarten commented 4 years ago
from routing_transformer.autoregressive_wrapper import top_p, top_k
import torch
import torch.nn.functional as F

test_tensor = torch.tensor([[0.1, 0.2, 0.15, 0.4, 0.3, 0.001, 0.01]])
threshold = 0.3

print(test_tensor)
print(F.softmax(test_tensor, dim=-1))

print("Top K")
print(top_k(test_tensor, thres=threshold))
print(F.softmax(top_k(test_tensor, thres=threshold), dim=-1))

print("Top P")
print(top_p(test_tensor, thres=threshold))
print(F.softmax(top_p(test_tensor, thres=threshold), dim=-1))

Output:

tensor([[0.1000, 0.2000, 0.1500, 0.4000, 0.3000, 0.0010, 0.0100]])
tensor([[0.1325, 0.1464, 0.1393, 0.1789, 0.1618, 0.1200, 0.1211]])
Top K
tensor([[  -inf, 0.2000, 0.1500, 0.4000, 0.3000,   -inf,   -inf]])
tensor([[0.0000, 0.2338, 0.2224, 0.2855, 0.2584, 0.0000, 0.0000]])
Top P
tensor([[  -inf,   -inf, 0.3000,   -inf, 0.4000,   -inf,   -inf]])
tensor([[0.0000, 0.0000, 0.4750, 0.0000, 0.5250, 0.0000, 0.0000]])

Thanks for writing this library! I think there is a bug in top_p, with two symptoms:

  1. The wrong results are filtered. It defines the threshold in the opposite way from top_k, so setting thres=0.9 returns tokens only until a cumulative probability of 0.1 is reached.
  2. The values are gathered twice (instead of being gathered and then scattered back), so the returned tensor's values end up in an essentially arbitrary order (see the sketch at the end of this comment).

I'll send over a PR in a few, hope it's helpful!
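
For reference, here is a minimal sketch of a gather-then-scatter nucleus filter. The function name top_p_fixed and the convention that thres is the cumulative probability mass to keep are assumptions for illustration only, not necessarily what the PR implements:

import torch
import torch.nn.functional as F

def top_p_fixed(logits, thres=0.9):
    # Sort descending and accumulate probability mass over the sorted logits.
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

    # Mask tokens once the kept mass exceeds thres, shifted right by one so the
    # token that crosses the threshold survives; the top token is always kept.
    sorted_remove = cum_probs > thres
    sorted_remove[..., 1:] = sorted_remove[..., :-1].clone()
    sorted_remove[..., 0] = False
    sorted_logits = sorted_logits.masked_fill(sorted_remove, float('-inf'))

    # Scatter the filtered values back to their original positions instead of
    # gathering a second time, so the output keeps the input's ordering.
    return torch.zeros_like(logits).scatter(-1, sorted_indices, sorted_logits)

Under this convention, running the script above with thres=0.3 keeps the two largest logits (0.4 and 0.3) in their original positions and masks everything else.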

lucidrains commented 4 years ago

@tomweingarten how is this framework panning out for you?

tomweingarten commented 4 years ago

Very well so far! I'm only just beginning to run tests comparing it to XLNet, but it seems very promising.

lucidrains commented 4 years ago

Awesome! Aurko (the first author) and @AranKomat would be pleased to hear that! Do let me know if you plan on training anything bigger (OpenWebText). Another researcher told me he ran into instability while training on that, so if you hit that obstacle and overcome it, please share!

tomweingarten commented 4 years ago

Interesting, thanks for the heads up, I'll be on the lookout. Do you know what their configuration settings were?

lucidrains commented 4 years ago

I asked him the same question; here is his response (I'm sure he wouldn't mind my pasting it):

Glad to hear that. I have been training two sets of runs on PG-19, one with length 4k and one with 8k. For 8k I have been using the following; you might have to play with num_decoder_layers depending on your device memory:

  hparams.max_length = 8192
  hparams.batch_size = 8192
  hparams.max_target_length = 8192
  hparams.hidden_size = 1024
  hparams.embedding_dims = 1024
  hparams.filter_size = 4096
  hparams.local_num_heads = 6
  hparams.sparsity_cluster_num_heads = 2
  hparams.num_decoder_layers = 25
  hparams.sparsity_skip_first = 23
  hparams.sparsity_cluster_size = 16
  hparams.query_shape = (256,)
  hparams.memory_flange = (256,)
  hparams.attention_dropout = 0.0
  hparams.relu_dropout = 0.0
  hparams.dropout = 0.0
  hparams.input_dropout = 0.0
  hparams.sparsity_cluster_attention_window = 512
  hparams.max_relative_position = 513
  hparams.weight_decay = 0

For 4k I can train with a larger num_decoder_layers, up to 28, and can match the 36-layer Compressive Transformer from DeepMind.
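
(For anyone reproducing a similar setup with this repository rather than Tensor2Tensor, a rough translation of the hparams above might look like the sketch below. The constructor keyword names are assumptions taken from the README and should be checked against the current version; hparams with no direct counterpart, such as sparsity_skip_first, are omitted.)

from routing_transformer import RoutingTransformerLM

# Rough, unverified mapping of the quoted hparams onto this library's LM.
# Keyword names may differ between versions; num_tokens is a placeholder.
model = RoutingTransformerLM(
    num_tokens = 32000,        # placeholder vocabulary size
    dim = 1024,                # hidden_size / embedding_dims
    depth = 25,                # num_decoder_layers
    heads = 8,                 # local_num_heads + sparsity_cluster_num_heads
    n_local_attn_heads = 6,    # local_num_heads; the remaining heads are routed
    max_seq_len = 8192,        # max_length
    window_size = 256,         # roughly query_shape / memory_flange
    causal = True,
)
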
AranKomat commented 4 years ago

@tomweingarten I think the product of batch_size and max_length (i.e. 8192^2 tokens per minibatch) may be a bit too large for the training to be compute-optimal. It would probably be better to reduce batch_size so that the product does not exceed 5 million or so. But I could be wrong.
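
(As a quick sanity check on that arithmetic, taking batch_size to count sequences as framed above:)

max_length = 8192
tokens_per_minibatch = 8192 * max_length        # 67,108,864 tokens per step
batch_for_5m_tokens = 5_000_000 // max_length   # about 610 sequences of length 8192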

Longformer performed better than RoBERTa, but note that some (or most) MLM downstream tasks probably won't benefit from longer-range dependency. If your results so far are satisfying, this remark may not apply. Also, although @lucidrains suggested using RT on OpenWebText, the average length of an OWT sample is a mere 350 words (as with any Common-Crawl-derived dataset), so the RT heads may not be worth the cost and local attention heads alone may suffice.

For PG-19 and Wikitext-103, I observed that the following trick improved performance dramatically (not sure about the large-scale setting, though). In the case of Wikitext-103, once you reshape your dataset from (batch_size x num_iters x max_length) into a (batch_size, num_iters, max_length) tensor, you can shuffle along the num_iters dimension. This makes the minibatches of consecutive iterations uncorrelated with each other, which improved training in my case; a sketch is below. If you ever try it, I'd really appreciate your feedback.
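
(A minimal sketch of the shuffling trick described above, assuming the corpus has already been tokenized into a flat 1-D tensor of token ids; the function name and signature are illustrative only:)

import torch

def shuffled_lm_batches(tokens, batch_size, max_length, seed=0):
    # Chop the flat token stream into batch_size contiguous rows, each made of
    # num_iters consecutive windows of max_length tokens.
    num_iters = tokens.numel() // (batch_size * max_length)
    data = tokens[: batch_size * num_iters * max_length]
    data = data.view(batch_size, num_iters, max_length)

    # Shuffle along the num_iters dimension so consecutive training iterations
    # no longer see adjacent, highly correlated chunks of text.
    perm = torch.randperm(num_iters, generator=torch.Generator().manual_seed(seed))
    data = data[:, perm, :]

    # One (batch_size, max_length) minibatch per iteration.
    for i in range(num_iters):
        yield data[:, i, :]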

lucidrains commented 4 years ago

@tomweingarten Hi Tom, just wanted to let you know that I've made an important update in the latest version, per the author's suggestion.