dptech-corp / Uni-Fold

An open-source platform for developing protein models beyond AlphaFold.
https://doi.org/10.1101/2022.08.04.502811
Apache License 2.0

add alphafold v3 param parse & config #122

ZiyaoLi closed this 1 year ago

lhatsk commented 1 year ago

AFAICT the parameters are never swapped. Isn't something missing here?

guolinke commented 1 year ago

@lhatsk

[image]

lhatsk commented 1 year ago

But LinearSwapParams doesn't do anything differently, does it? https://github.com/dptech-corp/Uni-Fold/blob/main/scripts/translate_jax_params.py#L174

The index and swap arguments are never used, so it's identical to LinearParams.

ZiyaoLi commented 1 year ago

@teslacool plz check whether the current implementation is correct and necessary.

lhatsk commented 1 year ago

I was expecting something along these lines:

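    import torch

    # Rotate along dim 0 (the output features of a torch Linear weight): rows
    # from `index` onward move to the front, rows before `index` to the back.
    LinearSwapParams = lambda l, index: {
        "weights": LinearWeight(torch.cat([l.weight[index:], l.weight[:index]])),
        "bias": LinearBias(torch.cat([l.bias[index:], l.bias[:index]])),
    }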
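
A quick way to see what the expected LinearSwapParams does is to check its effect on the layer's output. The following is a minimal NumPy stand-in for the torch version, assuming PyTorch's (out_features, in_features) weight layout; swap_linear_params is a hypothetical helper for illustration, not part of Uni-Fold.

```python
import numpy as np


def swap_linear_params(weight, bias, index):
    """Hypothetical helper: rotate output features so rows [index:] come first.

    Mirrors the torch.cat slicing in the snippet, with NumPy arrays standing in
    for torch tensors. `weight` is assumed to be (out_features, in_features).
    """
    new_weight = np.concatenate([weight[index:], weight[:index]], axis=0)
    new_bias = np.concatenate([bias[index:], bias[:index]], axis=0)
    return new_weight, new_bias


rng = np.random.default_rng(0)
weight = rng.normal(size=(6, 4))
bias = rng.normal(size=6)
x = rng.normal(size=4)

w2, b2 = swap_linear_params(weight, bias, index=2)
# The rotated layer returns the original outputs rotated left by `index`.
print(np.allclose(w2 @ x + b2, np.roll(weight @ x + bias, -2)))  # True
```

Because the rotation only permutes output features, it changes the model only where downstream code splits that output at a fixed position. With index equal to half the output size it exchanges the two halves, and rotating again by n - index undoes it.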

Tbf, I didn't do this either, and my guess is that the network can correct for it. Both versions behave very weirdly at the beginning of fine-tuning, though: the initial loss is large, so I suspect I am missing something.
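
One way a missing swap could show up as a large initial loss, sketched under explicit assumptions: suppose a fused projection with 2*c outputs is split into a gate half and a value half, and the source checkpoint stores those halves in the opposite order. The layout, the gated_output helper, and the sizes are hypothetical illustrations; the thread does not establish that this is the layer in question.

```python
import numpy as np


def gated_output(weight, bias, x, c):
    """Hypothetical fused projection: first c outputs gate, last c are values."""
    y = weight @ x + bias
    gate, value = y[:c], y[c:]
    return 1.0 / (1.0 + np.exp(-gate)) * value


rng = np.random.default_rng(0)
c, d_in = 3, 5
# Assumed source layout: value half first, gate half second.
w_src = rng.normal(size=(2 * c, d_in))
b_src = rng.normal(size=2 * c)
x = rng.normal(size=d_in)

# Reference: what the source model computes with its own layout.
y_src = w_src @ x + b_src
reference = 1.0 / (1.0 + np.exp(-y_src[c:])) * y_src[:c]

# Loading the parameters unchanged feeds the value half into the gate.
unswapped = gated_output(w_src, b_src, x, c)

# Rotating by index=c moves the gate half to the front, as the target expects.
w_fix = np.concatenate([w_src[c:], w_src[:c]])
b_fix = np.concatenate([b_src[c:], b_src[:c]])
swapped = gated_output(w_fix, b_fix, x, c)

print(np.allclose(swapped, reference))  # True
```

For generic weights the unswapped version computes a different function at step zero, which fine-tuning would first have to unlearn. That is consistent with, but does not prove, the guess that the network can eventually correct for it.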

teslacool commented 1 year ago

@ZiyaoLi pls review this PR.