We've previously supported running distributed edge calculations. We now support distributed basis builds as well.
We only support this:
for the jacobian basis
when has_pos is true
for dist_split_over=out_dim
with n_pods = 1 (we don't have communication between pods set up, which we need for distributing Cs)
We split over the first dimension of the phis array, which indexes all of the vectors we want to take our jacobian-vector products with. In the no-stochasticity case, there are out_pos * out_hidden different vectors. With stochasticity those factors become n_pos_sources and n_hidden_sources respectively.
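As a rough illustration of splitting over that first dimension, the sketch below assigns each process a contiguous range of phi indices. The helper name `split_indices` and the contiguous, near-even chunking are assumptions made for this example; the PR may divide the vectors differently.

```python
def split_indices(n_vectors: int, n_processes: int, rank: int) -> range:
    """Return the contiguous range of phi indices handled by `rank`.

    Hypothetical helper: the first `n_vectors % n_processes` ranks take one
    extra vector, so every index is assigned to exactly one process.
    """
    base, extra = divmod(n_vectors, n_processes)
    start = rank * base + min(rank, extra)
    stop = start + base + (1 if rank < extra else 0)
    return range(start, stop)


# No-stochasticity case: one phi vector per (out_pos, out_hidden) pair.
out_pos, out_hidden = 3, 4
n_vectors = out_pos * out_hidden
n_processes = 5
chunks = [split_indices(n_vectors, n_processes, rank) for rank in range(n_processes)]
```

With stochasticity, `n_vectors` would instead be `n_pos_sources * n_hidden_sources`.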
We combine across processes by adding the resulting M_dash matrices in data_accumulator.py.
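Adding the per-process results is valid when each vector contributes an additive term to the final matrix. The pure-Python sketch below assumes, purely for illustration, that each vector contributes its outer product; the real contribution to M_dash is computed differently, and `partial_m_dash` is a hypothetical name.

```python
def outer(v):
    """Outer product of a vector with itself, as a nested list."""
    return [[a * b for b in v] for a in v]


def add(m1, m2):
    """Elementwise sum of two equally shaped matrices."""
    return [[x + y for x, y in zip(r1, r2)] for r1, r2 in zip(m1, m2)]


def partial_m_dash(vectors, dim):
    """Sum the (assumed) per-vector contributions for one process."""
    total = [[0] * dim for _ in range(dim)]
    for v in vectors:
        total = add(total, outer(v))
    return total


vectors = [[1, 2], [3, 4], [5, 6], [7, 8]]

# Single-process result over all vectors.
full = partial_m_dash(vectors, dim=2)

# Two processes, each handling half of the first dimension, then summed.
combined = add(partial_m_dash(vectors[:2], dim=2), partial_m_dash(vectors[2:], dim=2))
```

Because the accumulation is a plain sum, the combined result does not depend on how the vectors are divided among processes.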
This PR also includes a refactor of the jacobian basis calculation, by:
Separating out the phi array generation into a helper function.
Changing the phi shape significantly, from batch r_hidden r_pos out_hidden out_pos to (r_hidden r_pos) batch out_pos out_hidden. This flattens the two source dimensions into one so that we can split over it, and reorders hidden and pos to match the convention used for activations.
Using PyTorch's built-in jacobian-vector product calculation.
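The shape change and the jacobian-vector product can be sketched as follows. This is a minimal example under stated assumptions: the PR does not say which PyTorch API it calls, so `torch.func.jvp` is used here as one built-in option, and the function `f`, the tensor sizes, and the use of `permute` plus `reshape` are placeholders rather than the PR's code.

```python
import torch
from torch.func import jvp

# Old layout: batch r_hidden r_pos out_hidden out_pos (sizes are arbitrary).
batch, r_hidden, r_pos, out_hidden, out_pos = 2, 3, 4, 5, 6
old_phis = torch.zeros(batch, r_hidden, r_pos, out_hidden, out_pos)

# New layout: (r_hidden r_pos) batch out_pos out_hidden.
# The flattened leading dimension is the one that gets split across processes.
new_phis = old_phis.permute(1, 2, 0, 4, 3).reshape(
    r_hidden * r_pos, batch, out_pos, out_hidden
)


def f(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for the part of the network being differentiated.
    return 2.0 * torch.tanh(x)


x = torch.tensor([0.0, 1.0, -1.0])
phi = torch.tensor([1.0, 0.0, 0.0])

# jvp returns f(x) and J(x) @ phi without materialising the full jacobian J.
out, jvp_out = jvp(f, (x,), (phi,))
```

Computing the product directly avoids building the full jacobian for every vector, which matters when there are many phi vectors per batch.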
Motivation and Context
Relevant past PRs for distributed edge calculations include:
#196, #319
How Has This Been Tested?
A test compares the basis computed with one process against the basis computed with two processes.
Distributed calculation for basis
Does this PR introduce a breaking change?
No.