The goal of this PR was to create a SequentialComponent with in_dims and out_dims properties that could be used elsewhere, and to make it cleaner to get the bias positions.
I was mistaken about the use case for out_dims. My current thinking is that this PR is only worthwhile if it makes handling bias positions much better.
Regardless, there are still these issues with the current commits:
The in_dims do not adjust to whether the bias was folded in the previous module, so we get shape mismatches. This is because, in this implementation, in_dims is set when a SequentialComponent is initialised, but bias folding happens afterwards. We would need to add something that updates in_dims when the bias has been folded in the previous layer. I'm not sure of a nice way to do this, and the whole thing is adding complexity.
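One possible way around the stale in_dims, sketched in plain Python rather than with the PR's actual classes: derive in_dims on demand from the previous component's current state instead of storing it at initialisation. ComponentSketch, SequenceSketch, raw_in_dims and raw_out_dims are hypothetical names, and the sketch assumes that folding a bias adds exactly one extra output dimension, which may not match the real folding code.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class ComponentSketch:
    """Hypothetical stand-in for a SequentialComponent (not the PR's class)."""

    raw_in_dims: int   # input width before any bias folding
    raw_out_dims: int  # output width before any bias folding
    folded_bias: bool = False

    @property
    def out_dims(self) -> int:
        # Assumption: folding the bias appends one constant feature to the output.
        return self.raw_out_dims + int(self.folded_bias)


class SequenceSketch:
    """Owns the ordering, so components never hold references to each other."""

    def __init__(self, components: List[ComponentSketch]) -> None:
        self.components = components

    def in_dims(self, idx: int) -> int:
        # Derived on demand from the previous component's current state,
        # so folding that happens after construction is picked up.
        if idx == 0:
            return self.components[0].raw_in_dims
        return self.components[idx - 1].out_dims


seq = SequenceSketch([ComponentSketch(4, 8), ComponentSketch(8, 2)])
before = seq.in_dims(1)               # previous bias not folded yet
seq.components[0].folded_bias = True  # folding happens after initialisation
after = seq.in_dims(1)                # now reflects the folded bias
```

The trade-off is that in_dims becomes a lookup on the sequence rather than a property of the component itself, which may or may not fit how in_dims is used elsewhere.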
Note that my first implementation stored the previous component inside each component. This was a problem because calling state_dict() would then output that previous component (and thus all components in the sequence) twice. A better solution would have been to create a dictionary somewhere and just store a pointer to the previous module instance. That seemed a little messy, and I thought I only needed the out_dims of the previous module, so I stored those directly (they are equivalent to the in_dims of the current module). Then I ran into the issue above: in_dims also depends on whether the previous module has folded_bias set.
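For reference, the dictionary idea could look roughly like the following. This is a plain-Python sketch with hypothetical names (ComponentStub, build_previous_map), not code from this PR. The point is that the link to the previous component lives in a side table, so the component never holds it as an attribute; with real nn.Module components that matters because assigning a module as an attribute registers it as a submodule, which is what pulls it into state_dict().

```python
from typing import Dict, List, Optional


class ComponentStub:
    """Hypothetical stand-in for a component; holds no link to its neighbours."""

    def __init__(self, name: str, out_dims: int) -> None:
        self.name = name
        self.out_dims = out_dims


def build_previous_map(
    components: List[ComponentStub],
) -> Dict[str, Optional[ComponentStub]]:
    """Map each component name to the component before it (None for the first)."""
    previous: Dict[str, Optional[ComponentStub]] = {}
    prev: Optional[ComponentStub] = None
    for comp in components:
        previous[comp.name] = prev
        prev = comp
    return previous


layers = [
    ComponentStub("linear_0", 8),
    ComponentStub("act_0", 8),
    ComponentStub("linear_1", 2),
]
previous = build_previous_map(layers)

# The in_dims of "linear_1" can be read from its predecessor at lookup time.
linear_1_in_dims = previous["linear_1"].out_dims
```

Because the map stores the live instance rather than a copied number, anything that later changes on the previous component (such as a folded bias) is visible through the same lookup.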
I think it's only worth continuing here if we deem it to be very useful for centering. Otherwise we can bin this.
Create Sequential Component
Description
Related Issue
Motivation and Context
How Has This Been Tested?
Does this PR introduce a breaking change?