minsoo2018 opened this issue 2 years ago.
Luckily, I solved this issue by using `hk.nets.MLP`, but I still don't understand why that fixed it.
import haiku as hk
import jax.numpy as jnp

class MyLSTM(hk.Module):
    def __init__(self, hidden_size, output_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.LSTM = hk.LSTM(self.hidden_size)
        self.MLP = hk.nets.MLP([self.hidden_size, self.output_size])  # activation defaults to relu

    def __call__(self, seq, label_encode):
        batch_size = seq.shape[0]
        # seq is batch-major ([batch, time, features]), hence time_major=False
        hidden, state = hk.dynamic_unroll(
            self.LSTM, seq, self.LSTM.initial_state(batch_size), time_major=False)
        # hidden[:, -1] is the LSTM output at the last time step
        features = jnp.concatenate([hidden[:, -1], label_encode], axis=-1)
        return self.MLP(features), state
Could someone give me a hint about why this resolved the issue?
Thank you for reading.
Hi, everyone. I got

TypeError: missing a required argument: 'x'

even though there is no `x` in my code. The purpose of this code is simple: use `jnp.concatenate` to combine the last hidden vector of the LSTM (`state`) with extra information (`label_encode`). I don't think there are any errors in the code up to this point, but I got the following error messages. As you can see, there is no `x` in the code. I double-checked the dimensions of every array. Could somebody suggest how to fix this issue? Thank you for reading.
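A possible explanation, offered as an assumption since the code that originally failed isn't shown: the wording `missing a required argument: 'x'` is the message Python's `inspect.Signature.bind` raises when a required parameter is not supplied, whereas an ordinary function call reports `missing 1 required positional argument` instead. As far as I know, JAX resolves the arguments of `custom_jvp` functions such as `jax.nn.relu` (whose single parameter is named `x`) through `inspect.signature(...).bind`, so calling an activation with no input, e.g. writing `jax.nn.relu()` where `jax.nn.relu` itself should be passed, would raise exactly this error. `hk.nets.MLP` applies its activation internally, which would be consistent with the switch making the error disappear. The stdlib-only sketch below reproduces the message; `relu` and `call_via_bind` are hypothetical stand-ins, not JAX code.

```python
import inspect


def relu(x):
    """Hypothetical stand-in for an activation whose only parameter is named x."""
    return max(x, 0.0)


def call_via_bind(fn, *args, **kwargs):
    # Resolve the arguments against fn's signature before calling it, as
    # wrappers that normalize positional/keyword arguments do.
    bound = inspect.signature(fn).bind(*args, **kwargs)
    return fn(*bound.args, **bound.kwargs)


print(call_via_bind(relu, -2.0))  # 0.0

try:
    call_via_bind(relu)  # activation "called" with no input
except TypeError as err:
    message = str(err)
print(message)  # missing a required argument: 'x'

try:
    relu()  # a plain call words the same mistake differently
except TypeError as err:
    plain_message = str(err)
print(plain_message)
```

If this is the cause, the traceback should point at the line where the activation is called rather than at `jnp.concatenate`.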