google-deepmind / dm-haiku

JAX-based neural network library
https://dm-haiku.readthedocs.io
Apache License 2.0

"TypeError: missing a required argument: 'x'" even though there is no 'x' in the code #460

Open minsoo2018 opened 2 years ago

minsoo2018 commented 2 years ago

Hi, everyone. I got "TypeError: missing a required argument: 'x'" even though there is no x in the code.

import haiku as hk
import jax
import jax.numpy as jnp

def concat(a, b, dim=-1):
  return jnp.concatenate((a, b), axis=dim)

class MyLSTM(hk.Module):
  def __init__(self, hidden_size, output_size):
    super().__init__()
    self.hidden_size = hidden_size
    self.output_size = output_size

    self.LSTM = hk.LSTM(self.hidden_size)
    self.Linear_cat = hk.Linear(self.hidden_size)
    self.Linear_final = hk.Linear(self.output_size)
    self.relu = jax.nn.relu()

  def __call__(self, seq, label_encode):
    batch_size = seq.shape[0]
    hidden, cell = hk.dynamic_unroll(self.LSTM, seq, self.LSTM.initial_state(batch_size), time_major=False)
    return self.Linear_final(self.relu(self.Linear_cat(concat(hidden[:, -1], label_encode))))

def forward_MyLSTM(seq, label_encode):
  lstm = MyLSTM(HIDDEN_SIZE, OUTPUT_SIZE)
  return lstm(seq, label_encode)

LSTMnet = hk.transform(forward_MyLSTM)
rng = jax.random.PRNGKey(428)  # Create a pseudo-random number generator (PRNG) key from an integer seed.
sample_x, sample_y = next(train_ds)
encode = sample_y[:, 4:]
params = LSTMnet.init(rng, sample_x, encode)

The purpose of this code is simple: use jnp.concatenate to combine the last hidden vector of the LSTM (state) with extra information (label_encode). I don't think there are any errors in the code up to this point, but I got the following error message.

---------------------------------------------------------------------------
UnfilteredStackTrace                      Traceback (most recent call last)
<ipython-input-76-cf73154e02e8> in <module>()
      4 encode = sample_y[:,4:]
----> 5 params = LSTMnet.init(rng, sample_x, encode)

21 frames
UnfilteredStackTrace: TypeError: missing a required argument: 'x'

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

TypeError                                 Traceback (most recent call last)
/usr/lib/python3.7/inspect.py in _bind(self, args, kwargs, partial)
   2928                             msg = 'missing a required argument: {arg!r}'
   2929                             msg = msg.format(arg=param.name)
-> 2930                             raise TypeError(msg) from None
   2931             else:
   2932                 # We have a positional argument to process

TypeError: missing a required argument: 'x'

As you can see, there is no x in the code. I double-checked the dimensions of every array in my code. Could somebody suggest how to fix this issue?

Thank you for reading.
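
A note on where a name like 'x' can come from: the last frame of the traceback is in inspect.py's _bind, which runs when the arguments of a call are matched against a function's signature. The name in the message is therefore a parameter of whichever function was called with too few arguments, not necessarily a variable in the calling code. The sketch below reproduces the same message using only the standard library; `relu_like` and `bind_error` are hypothetical names introduced for illustration and are not part of JAX or Haiku.

```python
import inspect


def relu_like(x):
    """Hypothetical stand-in for a function whose only parameter is named x."""
    return max(x, 0.0)


def bind_error(fn, *args, **kwargs):
    """Return the TypeError message raised when binding args to fn, or None."""
    try:
        inspect.signature(fn).bind(*args, **kwargs)
    except TypeError as err:
        return str(err)
    return None


# Binding with no arguments fails, and the message names the callee's parameter.
print(bind_error(relu_like))       # missing a required argument: 'x'
# Binding with one argument succeeds.
print(bind_error(relu_like, 1.0))  # None
```

Since the name belongs to the callee's signature, it is worth checking every call in the code, including those made in `__init__`, for a function invoked with too few arguments.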

minsoo2018 commented 2 years ago

Luckily I solved this issue by using hk.nets.MLP, but I still don't know why the issue was resolved.

class MyLSTM(hk.Module):
  def __init__(self, hidden_size, output_size):
    super().__init__()
    self.hidden_size = hidden_size
    self.output_size = output_size

    self.LSTM = hk.LSTM(self.hidden_size)
    self.MLP = hk.nets.MLP([self.hidden_size, self.output_size])  # activation defaults to jax.nn.relu

  def __call__(self, seq, label_encode):
    batch_size = seq.shape[0]
    hidden, state = hk.dynamic_unroll(self.LSTM, seq, self.LSTM.initial_state(batch_size), time_major = False)
    return self.MLP(concat(hidden[:,-1], label_encode)), state

Could someone give a hint as to why this change resolves the issue?

Thank you for reading.