I found a bug in the code: when I use a bidirectional GRU, the dimensions of the decoder's hidden state don't match during inference.
Traceback (most recent call last):
  File "inference.py", line 80, in <module>
    main(args)
  File "inference.py", line 44, in main
    samples, z = model.inference(n=args.num_samples)
  File "/home/bli/Binyun/Generation/Sentence-VAE/model.py", line 153, in inference
    output, hidden = self.decoder_rnn(input_embedding, hidden)
  File "/home/bli/.conda/envs/Xihe/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/bli/.conda/envs/Xihe/lib/python3.8/site-packages/torch/nn/modules/rnn.py", line 819, in forward
    self.check_forward_args(input, hx, batch_sizes)
  File "/home/bli/.conda/envs/Xihe/lib/python3.8/site-packages/torch/nn/modules/rnn.py", line 229, in check_forward_args
    self.check_hidden_size(hidden, expected_hidden_size)
  File "/home/bli/.conda/envs/Xihe/lib/python3.8/site-packages/torch/nn/modules/rnn.py", line 223, in check_hidden_size
    raise RuntimeError(msg.format(expected_hidden_size, list(hx.size())))
RuntimeError: Expected hidden size (2, 10, 256), got [1, 2, 10, 256]
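The shapes suggest what goes wrong: a GRU expects its initial state as (num_layers * num_directions, batch, hidden_size), here (2, 10, 256), but the state passed in has an extra leading dimension of 1. That extra dimension is only appropriate in the single-layer, unidirectional case, where the flat projection of z has to be turned into (1, batch, hidden_size). I haven't checked this line by line against model.py, so the following is a minimal PyTorch sketch under assumptions: the sizes (hidden 256, batch 10, one bidirectional layer) are taken from the error message, while `latent2hidden`, `init_hidden_buggy`, `init_hidden_fixed` and the latent/embedding sizes are hypothetical stand-ins. It reproduces the mismatch and shows a fix that applies `unsqueeze(0)` only when the state is not already 3-D.

```python
import torch
import torch.nn as nn

# Assumed sizes, chosen to reproduce the shapes in the traceback:
# 1 layer x 2 directions = 2, batch of 10, hidden size 256.
# LATENT and EMBED are arbitrary placeholders.
BATCH, HIDDEN, LATENT, EMBED = 10, 256, 16, 32
NUM_LAYERS, BIDIRECTIONAL = 1, True
HIDDEN_FACTOR = NUM_LAYERS * (2 if BIDIRECTIONAL else 1)

torch.manual_seed(0)
decoder_rnn = nn.GRU(EMBED, HIDDEN, num_layers=NUM_LAYERS,
                     bidirectional=BIDIRECTIONAL, batch_first=True)
# Hypothetical projection from z to a flat initial state (stands in for the model's own layer).
latent2hidden = nn.Linear(LATENT, HIDDEN * HIDDEN_FACTOR)


def init_hidden_buggy(z):
    """Reproduce the reported shape: unflatten, then unsqueeze unconditionally."""
    hidden = latent2hidden(z).view(HIDDEN_FACTOR, z.size(0), HIDDEN)
    return hidden.unsqueeze(0)  # (1, 2, batch, hidden): one dimension too many


def init_hidden_fixed(z):
    """Return a GRU initial state of shape (num_layers * num_directions, batch, hidden)."""
    hidden = latent2hidden(z)  # (batch, HIDDEN_FACTOR * HIDDEN)
    if BIDIRECTIONAL or NUM_LAYERS > 1:
        # Split each sample's vector into HIDDEN_FACTOR chunks, then move that axis first,
        # so every slice hidden[:, b] comes only from sample b's own latent code.
        hidden = hidden.view(z.size(0), HIDDEN_FACTOR, HIDDEN).transpose(0, 1).contiguous()
    else:
        # Single-layer, unidirectional GRU: only the leading dimension of 1 is missing.
        hidden = hidden.unsqueeze(0)
    return hidden


z = torch.randn(BATCH, LATENT)
step_input = torch.randn(BATCH, 1, EMBED)  # one decoding step, batch_first layout

buggy_error = None
try:
    decoder_rnn(step_input, init_hidden_buggy(z))
except RuntimeError as err:  # the exact message differs between PyTorch versions
    buggy_error = err
    print("buggy:", err)

output, hidden = decoder_rnn(step_input, init_hidden_fixed(z))
print(output.shape, hidden.shape)  # torch.Size([10, 1, 512]) torch.Size([2, 10, 256])
```

In the fixed version I reshape to (batch, factor, hidden) before transposing on purpose: calling `view(factor, batch, hidden)` directly on a (batch, factor * hidden) tensor also yields the right shape, but it reinterprets memory so that some slices of the state come from other samples' latent codes.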