I found that an additional dimension, representing the number of iterations, was added to the input data. What is the purpose of this dimension? Will all model parameters also gain an extra dimension during training? I did not see this kind of data processing in the original JAX version of AlphaFold.