schelotto / Neural_Speed_Reading_via_Skim-RNN_PyTorch

PyTorch implementation of "Neural Speed Reading via Skim-RNN"
MIT License

TypeError: argmax() missing 1 required positional argument: 'input' #8

Open MrBellamonte opened 5 years ago

MrBellamonte commented 5 years ago

I have the following error:

Loading data...
The device argument should be set by using torch.device or passing a string as an argument. This behavior will be deprecated soon and currently defaults to cpu.
The device argument should be set by using torch.device or passing a string as an argument. This behavior will be deprecated soon and currently defaults to cpu.

Parameters:
    BATCH_SIZE=16
    CUDA=False
    DEVICE=0
    DROPOUT=0.5
    EMBED_DIM=300
    EPOCHS=256
    GAMMA=0.01
    HIDDEN_LAYER=200
    LARGE_CELL_SIZE=100
    LOG_INTERVAL=1
    LR=0.001
    N_CLASS=2
    NUM_LAYERS=1
    PREDICT=None
    SAVE_DIR=snapshot/2019-04-04_11-53-47
    SAVE_INTERVAL=500
    SHUFFLE=False
    SMALL_CELL_SIZE=5
    SNAPSHOT=None
    TAU=0.5
    TEST=False
    TEST_INTERVAL=100
    VOCAB_SIZE=13962
    WORD_DICT=<torchtext.vocab.Vocab object at 0x129d1e0f0>
100%|████████████████████| 433/433 [00:29<00:00, 14.67it/s]
Epoch [1/256] | Step [434/433] | Loss: 0.5634 | Acc: 0.6202 | Precision: 0.8111 | Recall: 0.2686
<class 'torch.Tensor'>
Traceback (most recent call last):
  File "main.py", line 151, in <module>
    logits, h_stack, Q_stack = skim_rnn_classifier(sent)
  File "/Users/simonschonenberger/PycharmProjects/ETH/venv/lib/python3.6/site-packages/torch/nn/modules/module.py", line 489, in __call__
    result = self.forward(*input, **kwargs)
  File "/Users/simonschonenberger/PycharmProjects/ETH/CUHK/IntroDL/SKIM-RNN2/skim_rnn.py", line 86, in forward
    r_t = self.gumbel_softmax(p_t, tau).unsqueeze(1)
  File "/Users/simonschonenberger/PycharmProjects/ETH/CUHK/IntroDL/SKIM-RNN2/skim_rnn.py", line 60, in gumbel_softmax
    Q_t = torch.argmax()
TypeError: argmax() missing 1 required positional argument: 'input'

When I check the code, argmax is indeed called without any argument:

def gumbel_softmax(self, x, tau = 1.0):
    if self.training:
        u = torch.rand_like(x)
        g = -torch.log(-torch.log(u))
        tau_inverse = 1. / tau
        r_t = F.softmax(g * tau_inverse, -1)
        return r_t
    else:
        Q_t = torch.argmax()
        return Q_t.float()

  1. Why is there no argument?
  2. What could be the reason that I have this error while nobody else seems to have it?
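
For reference, here is a minimal sketch of what the call might look like with an argument supplied. It assumes the intent was to take the hard decision over the last dimension of `x`; that is a guess, not something confirmed by the repository author. The tensor `p_t` and the helper `hard_decision` are hypothetical and exist only for illustration.

```python
import torch

# Hypothetical skim probabilities: 3 tokens, 2 choices each (illustration only).
p_t = torch.tensor([[0.9, 0.1],
                    [0.2, 0.8],
                    [0.6, 0.4]])

# Calling torch.argmax() with no tensor reproduces the TypeError from the traceback.
try:
    torch.argmax()
except TypeError as err:
    print("TypeError:", err)


def hard_decision(x):
    # Assumption: the evaluation branch should pick the index of the largest
    # entry along the last dimension, then cast to float as the original does.
    return torch.argmax(x, dim=-1).float()


Q_t = hard_decision(p_t)
print(Q_t)  # tensor([0., 1., 0.])
```

Note that this returns one value per row, so the result has one dimension fewer than `x`. Whether that shape matches what the rest of `forward` expects is not something I can tell from the snippet above.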