Open SixtyTrees opened 3 months ago
Hey @SixtyTrees
`nn.Sequential` does basically what its name says: it creates a model that is a sequence of layers. It is just syntactic sugar for feeding the output of each layer as input to the next one.
That being said, you cannot implement skip connections with `nn.Sequential` alone. For that you have to subclass `nn.Module` and define the forward pass yourself.
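To make the "syntactic sugar" point concrete, here is a minimal sketch comparing `nn.Sequential` with the same layers chained by hand. The layer sizes and variable names are hypothetical, picked only for illustration.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical layers; the sizes are assumptions for this illustration.
layer_a = nn.Linear(30, 10)
layer_b = nn.ReLU()
layer_c = nn.Linear(10, 5)

# nn.Sequential calls each module in order on the previous module's output.
seq_model = nn.Sequential(layer_a, layer_b, layer_c)

x = torch.randn(4, 30)
seq_out = seq_model(x)

# The same computation written out by hand.
manual_out = layer_c(layer_b(layer_a(x)))

# Both paths use the same layer objects, so the results should agree.
same_result = torch.allclose(seq_out, manual_out)
```

Because each layer only ever sees the output of the one before it, there is no place in this chain to reuse an earlier output, which is why a skip connection needs a custom `forward`.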
Note that for a skip connection to work, the combined tensor must have the dimension that the receiving layer expects as input. There are two ways to implement skip connections (that I know of): the first is concatenation, and the second (with an example below) is addition.
In a ResNet-like architecture (which uses addition), you'd need the output dims of `linear2` and `linear3` to match the input dim of `linear4`. This means:

    nn.Linear(10, 10),  # linear2 (out dim is 10)
    ...
    nn.Linear(10, 10),  # linear3 (in and out dims are 10) -- note that you have to change the in dim of linear3 as well
    ...
    nn.Linear(10, 20),  # linear4 (in dim is 10)
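As a rough sketch of how those dimensions fit together in a module, assuming a first layer `linear1` of shape 30 to 10 (the thread does not give it) and leaving out whatever the `...` stands for:

```python
import torch
from torch import nn


class AdditiveSkipModel(nn.Module):
    """Hypothetical model: linear4 receives linear2's output plus linear3's output."""

    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(30, 10)  # assumed shape, not from the thread
        self.linear2 = nn.Linear(10, 10)
        self.linear3 = nn.Linear(10, 10)
        self.linear4 = nn.Linear(10, 20)

    def forward(self, inputs):
        l1_out = self.linear1(inputs)
        l2_out = self.linear2(l1_out)
        l3_out = self.linear3(l2_out)
        # Element-wise addition requires l2_out and l3_out to have the same shape.
        return self.linear4(l2_out + l3_out)


model = AdditiveSkipModel()
out = model(torch.randn(4, 30))  # batch of 4 samples with 30 features each
```

Addition leaves the feature size unchanged, so `linear4` still takes 10 input features.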
In a UNet-like architecture (which uses concatenation), you'd need the sum of the out dims of `linear2` and `linear3` to match the input dim of `linear4`:

    nn.Linear(10, 15),  # linear2 (out dim is 15)
    ...
    nn.Linear(15, 10),  # linear3 (out dim is 10)
    ...
    nn.Linear(25, 20),  # linear4 (in dim is 15 + 10 = 25)
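A minimal sketch of the concatenation variant with those dimensions, again assuming a hypothetical `linear1` of shape 30 to 10 and omitting the elided layers:

```python
import torch
from torch import nn


class ConcatSkipModel(nn.Module):
    """Hypothetical model: linear4 receives linear2's and linear3's outputs concatenated."""

    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(30, 10)  # assumed shape, not from the thread
        self.linear2 = nn.Linear(10, 15)
        self.linear3 = nn.Linear(15, 10)
        self.linear4 = nn.Linear(25, 20)  # 15 + 10 input features

    def forward(self, inputs):
        l1_out = self.linear1(inputs)
        l2_out = self.linear2(l1_out)
        l3_out = self.linear3(l2_out)
        # Concatenate along the feature dimension: (batch, 15) and (batch, 10) -> (batch, 25).
        merged = torch.cat([l2_out, l3_out], dim=-1)
        return self.linear4(merged)


model = ConcatSkipModel()
out = model(torch.randn(4, 30))  # batch of 4 samples with 30 features each
```

Unlike addition, concatenation does not require the two outputs to have the same feature size. Only their other dimensions (here the batch dimension) have to agree.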
Here's a simple example of how to create a skip connection using addition:

    from torch import nn


    class MySimpleSkipModel(nn.Module):
        def __init__(self):
            super().__init__()  # required before assigning submodules
            self.linear1 = nn.Linear(30, 10)
            self.linear2 = nn.Linear(10, 10)
            self.linear3 = nn.Linear(10, 5)

        def forward(self, inputs):
            l1_out = self.linear1(inputs)
            l2_out = self.linear2(l1_out)
            # using the outputs of both linear1 and linear2, use torch.cat for concat
            l3_out = self.linear3(l1_out + l2_out)
            return l3_out
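If you would still like to keep the outer model as an `nn.Sequential`, one possible pattern is to move the skip into a small wrapper module. `Sequential` itself still only chains modules; the wrapper is what carries the skip. The `Residual` class below is a hypothetical helper written for this sketch, not something PyTorch provides.

```python
import torch
from torch import nn


class Residual(nn.Module):
    """Hypothetical helper: adds the wrapped block's output to its own input."""

    def __init__(self, block):
        super().__init__()
        self.block = block

    def forward(self, x):
        # The block must preserve the feature size for the addition to work.
        return x + self.block(x)


# Same layer sizes as the example above: the skip goes around the middle layer.
model = nn.Sequential(
    nn.Linear(30, 10),
    Residual(nn.Linear(10, 10)),
    nn.Linear(10, 5),
)

out = model(torch.randn(4, 30))  # batch of 4 samples with 30 features each
```

This only covers skips that start and end at the boundaries of the wrapped block. Anything that reaches across several stages, as in a UNet, still needs a custom `forward`.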
Say I have a sequential model like the one below, and I want `linear4` to accept input from both `linear2` and `linear3`. How do I do this? An example would be the UNet architecture.