Closed danbraunai-apollo closed 9 months ago
Just realised it's in_tuple_dims not out_tuple_dims. We can calculate in_tuple_dims based on the input to the hook_fn, so the feature proposed above won't be useful for that.
The above feature only seems useful for getting the dimension of the logits, and for getting the out_pos_size argument that's used for initialising inner_token_sums inside integrated_gradient_trapezoidal_jacobian_squared.
This isn't that useful. I'm going to downgrade to low priority.
We would also like this functionality so that we can remove in_tuple_dims from integrated_gradient_trapezoidal_jacobian_functional and integrated_gradient_trapezoidal_jacobian_squared, and stop partial-ing it into module_hat_partial.
We're also interested in knowing which dimensions are bias dimensions in the input and output of modules. This is needed for mean centering (https://github.com/ApolloResearch/rib/pull/249/).
This turned out to be quite messy to implement, perhaps messier than the problem it solves (see #272). For that reason, we're going to stick with the original solution.
There are two places in the code where we need to know the dimensions of the input/output of a module. One is in_tuple_dims in hook_fns.interaction_edge_pre_forward_hook_fn, the other is final_node_dim in calculate_interaction_rotations (oh I forgot, there is also initialising inner_token_sums in integrated_gradient_trapezoidal_jacobian_squared).
In (many cases within) both of these, we calculate the out_dim by passing data through the modules and measuring the output. This is gross. Instead, we can do the following:
Create a RIBModule or SequentialModule class which inherits from nn.Module and has an out_dims property that must be set (we can make it also inherit from ABC to enforce setting of the out_dims property).
We make all our components inherit from RIBModule/SequentialModule and define that property for themselves.
We add a property to the MultiSequential class called "out_dims", which just reads out_dims from the final module in the sequence and returns it.
It's a fair bit of inheriting, but I think I prefer it to just hoping the user defines an out_dims property on each component that they want to use. And I can imagine us wanting other properties defined on the modules used in RIB.
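A rough sketch of what the proposal above might look like. To keep it runnable without PyTorch, a plain `Module` class stands in for `nn.Module`; the `Linear` and `Forgetful` components, the `stages` attribute and the tuple-of-ints return type for `out_dims` are all made up for illustration and are not taken from the rib codebase.

```python
from abc import ABC, abstractmethod


class Module:
    """Stand-in for torch.nn.Module (assumption: lets the sketch run without PyTorch)."""


class RIBModule(Module, ABC):
    """Base class that forces every component to declare its output dims."""

    @property
    @abstractmethod
    def out_dims(self) -> tuple[int, ...]:
        """Size of each tensor in the module's output tuple (assumed meaning)."""


class Linear(RIBModule):
    """Hypothetical component with a single output tensor."""

    def __init__(self, d_in: int, d_out: int) -> None:
        self.d_in = d_in
        self.d_out = d_out

    @property
    def out_dims(self) -> tuple[int, ...]:
        return (self.d_out,)


class MultiSequential(Module):
    """Container whose out_dims is simply that of its final module."""

    def __init__(self, *modules: RIBModule) -> None:
        if not modules:
            raise ValueError("MultiSequential needs at least one module")
        self.stages = list(modules)

    @property
    def out_dims(self) -> tuple[int, ...]:
        # No forward pass needed: just ask the last module.
        return self.stages[-1].out_dims


class Forgetful(RIBModule):
    """Hypothetical component that forgets to define out_dims."""


seq = MultiSequential(Linear(4, 8), Linear(8, 3))
print(seq.out_dims)

try:
    Forgetful()  # ABC refuses to instantiate a class with an abstract property
except TypeError as err:
    print(f"rejected: {err}")
```

With the ABC in place, a component that omits `out_dims` fails at construction time rather than later inside a hook, which is the enforcement the proposal is after.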