schmidt-ju / crat-pred


Issue understanding the LSTM encoder definition #8

Closed louis-dv closed 1 year ago

louis-dv commented 1 year ago

Hello! First, thank you very much for the paper and code. It is very interesting.

I am not sure I correctly understand the way you define your LSTM encoder.

Indeed, looking at the code here:

    def forward(self, lstm_in, agents_per_sample):
        # lstm_in are all agents over all samples in the current batch
        # Format for LSTM has to be (batch_size, timeseries_length, latent_size), because batch_first=True

        # Initialize the hidden state.
        # lstm_in.shape[0] corresponds to the number of all agents in the current batch
        lstm_hidden_state = torch.randn(
            self.num_layers, lstm_in.shape[0], self.hidden_size, device=lstm_in.device)
        lstm_cell_state = torch.randn(
            self.num_layers, lstm_in.shape[0], self.hidden_size, device=lstm_in.device)
        lstm_hidden = (lstm_hidden_state, lstm_cell_state)

        lstm_out, lstm_hidden = self.lstm(lstm_in, lstm_hidden)

        # lstm_out is the hidden state over all time steps from the last LSTM layer
        # In this case, only the features of the last time step are used
        return lstm_out[:, -1, :]

I assume that lstm_hidden_state and lstm_cell_state will be different at each forward pass, since they are sampled randomly.

I actually tested this on a trained model: in evaluation mode, for the exact same input batch, the resulting lstm_out values are different across forward passes.
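
For reference, this is roughly the check I ran (encoder, lstm_in and agents_per_sample are placeholders for my trained encoder module and one fixed batch, not names from the repo):

    import torch

    # encoder, lstm_in and agents_per_sample are placeholders for my setup:
    # a trained encoder module and a single fixed input batch.
    encoder.eval()
    with torch.no_grad():
        out_1 = encoder(lstm_in, agents_per_sample)
        out_2 = encoder(lstm_in, agents_per_sample)

    # Prints False for me: two passes over the same batch give different features,
    # which I attribute to the torch.randn state initialization.
    print(torch.allclose(out_1, out_2))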

Am I misunderstanding something?

Thank you for the help!

schmidt-ju commented 1 year ago

Hey Louis,

thank you so much for the interesting question.

You are right, lstm_hidden_state and lstm_cell_state are generated randomly from a normal distribution. Hence, they are different at each forward pass. This is intentional!

For identical input batches and identical lstm_hidden_state and lstm_cell_state, the output of the LSTM should be consistent.

Using noise as the initial values of the two states is just one way to initialize them. I had good experiences with it during my early experiments. Zero states are probably the more common default. If you want to ensure that your output is always the same for a given input, zero initialization is the way to go.
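
Just as a rough sketch (untested, only to illustrate what I mean by zero initialization), the forward pass would change to something like this:

    def forward(self, lstm_in, agents_per_sample):
        # Zero-initialized hidden and cell states instead of torch.randn,
        # so the encoder output is deterministic for a fixed input batch
        lstm_hidden_state = torch.zeros(
            self.num_layers, lstm_in.shape[0], self.hidden_size, device=lstm_in.device)
        lstm_cell_state = torch.zeros(
            self.num_layers, lstm_in.shape[0], self.hidden_size, device=lstm_in.device)
        lstm_hidden = (lstm_hidden_state, lstm_cell_state)

        lstm_out, lstm_hidden = self.lstm(lstm_in, lstm_hidden)

        # Only the features of the last time step are used
        return lstm_out[:, -1, :]

Since PyTorch's nn.LSTM defaults to zero hidden and cell states when no initial states are passed, simply calling self.lstm(lstm_in) without the lstm_hidden tuple should behave the same way.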

Please update me if you do some experiments with zero initialization. I am super interested in the results :D

Hope this helps.

Julian