ApolloResearch / rib

Library for methods related to the Local Interaction Basis (LIB)
MIT License

Create Sequential Component ABC with module properties #231

Status: closed by danbraunai-apollo 9 months ago

danbraunai-apollo commented 10 months ago

There are two places in the code where we need to know the dimensions of the input/output of a module. One is in_tuple_dims in hook_fns.interaction_edge_pre_forward_hook_fn, the other is final_node_dim in calculate_interaction_rotations. (Oh, I forgot: there is also the initialisation of inner_token_sums in integrated_gradient_trapezoidal_jacobian_squared.)

In many cases within both of these, we calculate the out_dim by passing data through the modules and measuring the shape of the output. This is gross. Instead, we can do the following:

Create a RIBModule or SequentialModule class which inherits from nn.Module and has an out_dims property that must be set (it can also inherit from ABC to enforce that the out_dims property is defined).

We make all our components inherit from RIBModule/SequentialModule and define that property for themselves.

We add an "out_dims" property to the MultiSequential class, which simply returns the out_dims of the final module in the sequence.
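A minimal sketch of the proposal above, assuming PyTorch. RIBModule is the name suggested in the thread, but Linear is a hypothetical example component, and the MultiSequential here is a simplified stand-in built on nn.Sequential, not the class in the repository. Inheriting from ABC means a component that forgets to define out_dims fails at instantiation rather than later.

```python
from abc import ABC, abstractmethod

import torch
from torch import nn


class RIBModule(nn.Module, ABC):
    """Hypothetical base class: every RIB component must declare its output dims."""

    @property
    @abstractmethod
    def out_dims(self) -> tuple[int, ...]:
        """Size of the final dimension of each output tensor."""


class Linear(RIBModule):
    """Hypothetical example component wrapping nn.Linear."""

    def __init__(self, d_in: int, d_out: int) -> None:
        super().__init__()
        self.proj = nn.Linear(d_in, d_out)

    @property
    def out_dims(self) -> tuple[int, ...]:
        # Known from construction; no forward pass needed.
        return (self.proj.out_features,)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(x)


class MultiSequential(nn.Sequential):
    """Simplified stand-in for the repository's MultiSequential."""

    @property
    def out_dims(self) -> tuple[int, ...]:
        # The output of the sequence is the output of its final module.
        return self[-1].out_dims


model = MultiSequential(Linear(4, 8), Linear(8, 3))
print(model.out_dims)  # (3,)
```

The trade-off raised in the next paragraph is visible here: every component needs the extra base class, but forgetting out_dims raises a TypeError immediately instead of relying on convention.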

It's a fair bit of inheriting, but I think I prefer it to just hoping the user defines an out_dims property on each component that they want to use. And I can imagine us wanting other properties defined on the modules used in RIB.

danbraunai-apollo commented 10 months ago

Just realised it's in_tuple_dims, not out_tuple_dims. We can calculate in_tuple_dims from the input to the hook_fn, so the feature proposed above won't be useful for that.
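As a rough illustration of reading input dims from the hook's own inputs, here is a sketch using PyTorch's register_forward_pre_hook. The hook, the captured_in_dims dict and the Concat module are hypothetical and are not hook_fns.interaction_edge_pre_forward_hook_fn; the point is only that the dims are available from the tensors the hook already receives.

```python
import torch
from torch import nn

# Hypothetical cache of input dims, keyed by a caller-chosen module name.
captured_in_dims: dict[str, tuple[int, ...]] = {}


def make_pre_hook(name: str):
    def pre_hook(module: nn.Module, args: tuple) -> None:
        # args holds the positional inputs to forward; record the last dim of each.
        captured_in_dims[name] = tuple(x.shape[-1] for x in args)
        # Returning None leaves the inputs unchanged.

    return pre_hook


class Concat(nn.Module):
    """Hypothetical toy module that takes a tuple of two inputs."""

    def forward(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        return torch.cat([a, b], dim=-1)


concat = Concat()
handle = concat.register_forward_pre_hook(make_pre_hook("concat"))
concat(torch.randn(2, 4), torch.randn(2, 6))
handle.remove()
print(captured_in_dims["concat"])  # (4, 6)
```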

The above feature only seems useful for getting the dimension of the logits and the out_pos_size argument used to initialise inner_token_sums inside integrated_gradient_trapezoidal_jacobian_squared.

This isn't that useful. I'm going to downgrade to low priority.

stefan-apollo commented 10 months ago

We would also like this functionality so that we can remove the in_tuple_dims argument from integrated_gradient_trapezoidal_jacobian_functional and integrated_gradient_trapezoidal_jacobian_squared, and stop partial-ing it into module_hat_partial.
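A hedged sketch of what that could look like. It assumes, without confirmation from this thread, that in_tuple_dims is used to split a concatenated input back into a module's tuple of inputs. module_hat, module_hat_from_module, Concat and the in_dims property are all hypothetical names, not the functions in the repository.

```python
from functools import partial

import torch
from torch import nn


class Concat(nn.Module):
    """Hypothetical two-input module that declares its own input dims."""

    def __init__(self, d_a: int, d_b: int) -> None:
        super().__init__()
        self.d_a, self.d_b = d_a, d_b

    @property
    def in_dims(self) -> tuple[int, ...]:
        return (self.d_a, self.d_b)

    def forward(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        return torch.cat([a, b], dim=-1)


def module_hat(f_in_hat: torch.Tensor, module: nn.Module, in_tuple_dims: tuple[int, ...]):
    # Split the concatenated input into the module's tuple of inputs, then run it.
    return module(*torch.split(f_in_hat, list(in_tuple_dims), dim=-1))


def module_hat_from_module(f_in_hat: torch.Tensor, module: Concat):
    # Same computation, but the dims come from the module itself.
    return module_hat(f_in_hat, module, module.in_dims)


concat = Concat(4, 6)
# Current style: in_tuple_dims is computed elsewhere and partial-ed in.
old_partial = partial(module_hat, module=concat, in_tuple_dims=(4, 6))
# Proposed style: no dims need to be threaded through the call.
new_partial = partial(module_hat_from_module, module=concat)
```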

danbraunai-apollo commented 10 months ago

We're also interested in knowing which dimensions are bias dimensions in the input and output of modules. This is needed for mean centering https://github.com/ApolloResearch/rib/pull/249/

danbraunai-apollo commented 9 months ago

This turned out to be quite messy to implement, perhaps messier than the problem it solves (see #272). For that reason, we're going to stick with the original solution.