Closed: ZiyaoLi closed this 1 year ago
@lhatsk
But LinearSwapParams doesn't do anything differently, does it? https://github.com/dptech-corp/Uni-Fold/blob/main/scripts/translate_jax_params.py#L174
index and swap are never used. It's the same as LinearParams.
@teslacool please check whether the current implementation is correct and necessary.
I was expecting something along these lines:
LinearSwapParams = lambda l, index: {
    "weights": LinearWeight(torch.cat([l.weight[index:], l.weight[:index]])),
    "bias": LinearBias(torch.cat([l.bias[index:], l.bias[:index]])),
}
To be fair, I also didn't do this, and my guess is that the network can correct for it. Both versions behave very strangely at the beginning of fine-tuning, though: the initial loss is large, so I suspect I am still missing something.
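For reference, here is a minimal sketch of the rotation that the snippet above expects, written with plain Python lists so it runs without torch. The names rotate_rows and linear_swap_params are hypothetical and are not taken from translate_jax_params.py. The LinearWeight / LinearBias wrappers are left out, and the sketch assumes the swap applies along the first (output) dimension, which may not match the actual parameter layout.

```python
def rotate_rows(rows, index):
    """Move the first `index` rows to the end.

    Mirrors torch.cat([x[index:], x[:index]]) along dim 0, but on lists.
    """
    return rows[index:] + rows[:index]


def linear_swap_params(weight, bias, index):
    # Hypothetical stand-in for the expected LinearSwapParams behaviour:
    # both weight rows and bias entries are rotated by the same index,
    # so each output unit keeps its own bias.
    return {
        "weights": rotate_rows(weight, index),
        "bias": rotate_rows(bias, index),
    }


# Toy layer with 4 output rows and 2 input columns (made-up values).
weight = [[1, 1], [2, 2], [3, 3], [4, 4]]
bias = [10, 20, 30, 40]

swapped = linear_swap_params(weight, bias, index=2)
unswapped = {"weights": weight, "bias": bias}

# If index is actually used, the result must differ from the plain mapping.
print(swapped != unswapped)
```

A check like the last line is one quick way to see whether a swap implementation really uses index: if the swapped and unswapped outputs compare equal for a non-trivial index, the swap is a no-op.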
AFAICT the parameters are never swapped. Isn't something missing here?