mlech26l / ncps

PyTorch and TensorFlow implementation of NCP, LTC, and CfC wired neural models
https://www.nature.com/articles/s42256-020-00237-3
Apache License 2.0

Stacking wirings.FullyConnected with nn.Linear layers slows the training by 120x #29

Closed: sprakashdash closed this issue 1 year ago

sprakashdash commented 2 years ago

Before describing my issue, I would like to thank the authors for the incredible Neural Circuit Policies paper, which shows how networks of LTC neurons, wired following an architecture inspired by C. elegans, can robustly perform autonomous driving.

Motivated by the provided Keras code, I tried stacking an LTC layer with a wirings.FullyConnected() wiring after two linear layers. Here is the code:

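import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
# kncp (keras-ncp), LTCCell and RNNSequence are assumed to be imported as in the keras-ncp PyTorch example

class QNetwork_w_LTC(nn.Module):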
    def __init__(self, env):
        super(QNetwork_w_LTC, self).__init__()
        self.fc1 = nn.Linear(np.array(env.single_observation_space.shape).prod() + np.prod(env.single_action_space.shape), 64)
        self.fc2 = nn.Linear(64, 64)
        ###########################
        self.wiring = kncp.wirings.FullyConnected(units=16, output_dim=1)
        self.ltc_cell = LTCCell(wiring=self.wiring, in_features=5)
        ##########################
        # self.fc3 = nn.Linear(64, 1)

    def forward(self, x, a):
        x = torch.cat([x, a], 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        #######################
        x = x.unsqueeze(-1)
        ltc_sequence = RNNSequence(self.ltc_cell)
        x = ltc_sequence.forward(x)
        #######################
        return x

in place of the generic Q-network used in deep RL algorithms, which looks like this:

class QNetwork(nn.Module):
    def __init__(self, env):
        super(QNetwork, self).__init__()
        self.fc1 = nn.Linear(np.array(env.single_observation_space.shape).prod() + np.prod(env.single_action_space.shape), 64)
        self.fc2 = nn.Linear(64, 64)
        self.fc3 = nn.Linear(64, 1)

    def forward(self, x, a):
        x = torch.cat([x, a], 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

But during training, the vanilla nn.Linear network took 3 hours for 1.5 million timesteps, while the version with the wirings.FullyConnected() LTC layer stacked on the nn.Linear layers took 5 hours for just 20k timesteps.

My guess is that the slowdown comes from one of the following:

  1. Maybe I am using many more parameters than I should. (I could not really understand what output_dim and in_features are used for.)
  2. Maybe implementing the wirings.py file in PyTorch, e.g., making every matrix/tensor a torch.Tensor, would increase the speed?
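A quick way to test the first guess is to count the trainable parameters of each network. A minimal sketch, where `count_parameters` is an illustrative helper and `obs_dim=3`, `act_dim=1` are hypothetical placeholders for the sizes of `env.single_observation_space` and `env.single_action_space`:

```python
import torch.nn as nn


def count_parameters(model: nn.Module) -> int:
    """Return the number of trainable parameters in a module."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


# Hypothetical sizes standing in for the environment's observation/action spaces
obs_dim, act_dim = 3, 1

# Same layer sizes as the feedforward QNetwork
mlp = nn.Sequential(
    nn.Linear(obs_dim + act_dim, 64),
    nn.ReLU(),
    nn.Linear(64, 64),
    nn.ReLU(),
    nn.Linear(64, 1),
)
print(count_parameters(mlp))  # 4545
```

The same helper can be called on the LTC variant to compare the two models directly.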

Could you please let me know if there are any other options to speed up the code?

Thanks in advance for helping me out!

mlech26l commented 2 years ago

Hi @sprakashdash

Glad that you liked our work!

There are a few issues in the QNetwork_w_LTC:

First, in_features is the size of the input tensor (= number of input features), which is 64 in your case.

Second, RNNSequence processes entire sequences of samples rather than individual samples, so the input should have the shape (batch size, sequence length, in_features). Your code instead represents the input as (batch size, 64, 1), i.e., instead of processing the 64 features at once, it feeds them in one after another as a sequence of 64 time steps. This is probably why your code is so slow. (Note that the LTC is expected to be about 10x slower than a single layer because of the ODE solver.)
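The effect of the input layout can be sketched with a generic unroll loop. `CountingSequence` below is a hypothetical stand-in for RNNSequence (its internals are assumed, not taken from the library), and `nn.GRUCell` stands in for LTCCell. The wrapper calls the cell once per entry along dimension 1 and counts the calls; since the LTC runs an ODE solver inside every call, each extra time step is comparatively expensive.

```python
import torch
import torch.nn as nn


class CountingSequence(nn.Module):
    """Hypothetical RNNSequence stand-in: unrolls a cell over dim 1 and counts steps."""

    def __init__(self, cell: nn.GRUCell):
        super().__init__()
        self.cell = cell
        self.steps = 0

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch size, sequence length, in_features)
        batch_size, seq_len, _ = x.shape
        h = torch.zeros(batch_size, self.cell.hidden_size, device=x.device)
        outputs = []
        self.steps = 0
        for t in range(seq_len):
            h = self.cell(x[:, t], h)  # one cell call per time step
            outputs.append(h)
            self.steps += 1
        return torch.stack(outputs, dim=1)


features = torch.randn(32, 64)  # output of fc2: (batch size, 64)

# unsqueeze(-1): 64 time steps, each carrying a single feature
seq_wrong = CountingSequence(nn.GRUCell(input_size=1, hidden_size=16))
out_wrong = seq_wrong(features.unsqueeze(-1))

# unsqueeze(1): one time step carrying all 64 features
seq_right = CountingSequence(nn.GRUCell(input_size=64, hidden_size=16))
out_right = seq_right(features.unsqueeze(1))

print(seq_wrong.steps, tuple(out_wrong.shape))  # 64 (32, 64, 16)
print(seq_right.steps, tuple(out_right.shape))  # 1 (32, 1, 16)
```

Note that the cell's input size must match the last dimension of the input, which is exactly the role of in_features.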

More generally, the Q-learning you are doing does not seem to have a temporal component, i.e., it processes individual samples instead of also providing past observations. In such a context, I would not expect any advantage of RNNs (such as the LTC) over normal feedforward networks. If you want to give it a try anyway, you can create sequences of length 1 by replacing

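# x = x.unsqueeze(-1)  # (batch size, 64, 1): 64 time steps of 1 feature
# with
x = x.unsqueeze(1)     # (batch size, 1, 64): 1 time step of 64 features
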
sprakashdash commented 2 years ago

@mlech26l Thanks a lot for the heads-up. Now the training is only about 10x slower than with the generic nn.Linear layers! I understand that, since LTC cells are RNN cells in their abstraction, just training on random batches would not produce good results out of the box. But I was thinking of preserving the temporal order of the replay buffer and sampling random batches of consecutive <s,a,r,s'> transitions for training.
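Sampling such "causal" batches could look roughly like the sketch below. It assumes a hypothetical flat replay buffer stored in insertion order (no circular wrap-around) with a boolean `dones` array marking episode ends; `sample_sequence_indices` is an illustrative helper, not part of ncps. Each sampled window is contiguous and never crosses an episode boundary (an episode may only end at a window's last step).

```python
import numpy as np


def sample_sequence_indices(dones: np.ndarray, seq_len: int, batch_size: int,
                            rng: np.random.Generator) -> np.ndarray:
    """Return (batch_size, seq_len) indices of contiguous, single-episode windows.

    Assumes a flat buffer without wrap-around; dones[i] is True if
    transition i ended an episode.
    """
    n = len(dones)
    starts = np.arange(n - seq_len + 1)
    # ends[k] = number of episode ends among transitions 0 .. k-1
    ends = np.concatenate([[0], np.cumsum(dones.astype(int))])
    # episode ends inside a window, excluding its last position
    inner_dones = ends[starts + seq_len - 1] - ends[starts]
    valid = starts[inner_dones == 0]
    chosen = rng.choice(valid, size=batch_size, replace=True)
    return chosen[:, None] + np.arange(seq_len)[None, :]


rng = np.random.default_rng(0)
dones = np.array([False, False, True, False, False, False, True, False])
idx = sample_sequence_indices(dones, seq_len=3, batch_size=4, rng=rng)
print(idx.shape)  # (4, 3)
```

Indexing a stored observation array of shape (buffer size, obs_dim) with `idx` then yields (batch size, seq_len, obs_dim), i.e., the (batch size, sequence length, in_features) layout described earlier in the thread.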

The last thing I would like to ask about is the learning rate of LTC cells. In all the notebooks the learning rate is 0.01, while most RL function approximators use a much lower learning rate (3e-4). So does the learning rate depend on the application (supervised learning vs RL) or on the type of network (LTC vs ANN)?

Could you also tell me what learning rate you used when you trained LTC neurons on real car data for autonomous driving?

mlech26l commented 2 years ago

The LTC usually requires higher learning rates than feedforward networks.

So, in our examples, 0.01 or 0.005 (autonomous driving) worked okay for supervised learning.

As you mentioned, RL typically works better with smaller learning rates, so I would expect something between 3e-4 and 3e-3 to be a good start.
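If the LTC head and the feedforward layers seem to need different step sizes, one option (not something suggested in this thread) is to give them separate learning rates through PyTorch optimizer parameter groups. The layer sizes below are placeholders, and `nn.GRUCell` is a stand-in for the LTC cell:

```python
import torch
import torch.nn as nn

# Placeholder feedforward body (input size 4 is hypothetical)
body = nn.Sequential(nn.Linear(4, 64), nn.ReLU(), nn.Linear(64, 64), nn.ReLU())
# Stand-in for the LTC cell
head = nn.GRUCell(input_size=64, hidden_size=16)

optimizer = torch.optim.Adam([
    {"params": body.parameters(), "lr": 3e-4},  # typical RL learning rate
    {"params": head.parameters(), "lr": 3e-3},  # upper end of the suggested range
])
print([group["lr"] for group in optimizer.param_groups])  # [0.0003, 0.003]
```

Whether a split like this helps is an empirical question; a single learning rate in the suggested range is the simpler starting point.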

sprakashdash commented 2 years ago

As you previously mentioned, there won't be any advantage of LTCs over ANNs here, since there is no temporal component and the batch is sampled at random, but at least I was able to train the QNet with an LTC together with an ActorNet made of standard layers. However, when I ran my experiments with both the QNet and the Actor using LTCs, the actor was not able to learn. Here is the previous ActorNet:

class Actor(nn.Module):
    def __init__(self, env):
        super(Actor, self).__init__()
        self.fc1 = nn.Linear(np.array(env.single_observation_space.shape).prod(), 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc_mu = nn.Linear(256, np.prod(env.single_action_space.shape))

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return torch.tanh(self.fc_mu(x))

Here is the ActorNet with LTC:

class Actor(nn.Module):
    def __init__(self, env):
        super(Actor, self).__init__()
        self.fc1 = nn.Linear(np.array(env.single_observation_space.shape).prod(), 256)
        self.fc2 = nn.Linear(256, 256)
        # self.fc_mu = nn.Linear(256, np.prod(env.single_action_space.shape))
        ###########################
        self.wiring = kncp.wirings.FullyConnected(units=8, output_dim=np.prod(env.single_action_space.shape))
        self.ltc_cell = LTCCell(wiring=self.wiring, in_features=256)
        #replace the self.fc_mu layer
        ##########################

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = x.unsqueeze(1)
        ltc_sequence = RNNSequence(self.ltc_cell)
        x = ltc_sequence.forward(x)
        x = torch.tanh(x).squeeze(1)  # remove only the sequence dimension
        return x

As you pointed out, I have set in_features to the hidden dimension of the preceding layer, i.e., 256. But you also mentioned that the input tensor to the LTC sequence should have the shape [batch size, sequence length, in_features]. Right now the input to ltc_sequence has the shape [batch size, 1, 256], but I think it should be [batch size, action_dim, 256]. Could you help me understand how to create a tensor with that shape?

mlech26l commented 2 years ago

Hi, [batch size, 1, 256] is the correct shape, i.e., a batch of length-1 sequences, each having 256 features. The output tensor of the LTC network should then be [batch size, 1, action_dim]. The squeeze operation then removes the sequence dimension, leaving a [batch size, action_dim] tensor.
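One detail worth keeping in mind with the squeeze step: `Tensor.squeeze()` without an argument removes every dimension of size 1, not just the sequence dimension. With a one-dimensional action space, or a batch containing a single observation during rollouts, the actor output then has the wrong rank (which would, for example, break `torch.cat([x, a], 1)` in the critic). Passing the dimension explicitly removes only the sequence axis. A minimal shape check with placeholder sizes:

```python
import torch

batch_size, action_dim = 32, 1  # placeholder sizes
# LTC output for a batch of length-1 sequences: (batch size, 1, action_dim)
ltc_out = torch.randn(batch_size, 1, action_dim)

all_squeezed = torch.tanh(ltc_out).squeeze()   # drops every size-1 dim
seq_squeezed = torch.tanh(ltc_out).squeeze(1)  # drops only the sequence dim

print(all_squeezed.shape)  # torch.Size([32])
print(seq_squeezed.shape)  # torch.Size([32, 1])

# With a batch of one observation, squeeze() even returns a 0-d tensor
single = torch.randn(1, 1, action_dim)
print(single.squeeze().shape, single.squeeze(1).shape)  # torch.Size([]) torch.Size([1, 1])
```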

In the above code, you are using only 8 neurons for the LTC. This is very small and could be one of the reasons why the model is not learning. Maybe try 64 or 128. If that also does not work, try changing the learning rate.