aqlaboratory / openfold

Trainable, memory-efficient, and GPU-friendly PyTorch reproduction of AlphaFold 2
Apache License 2.0
2.71k stars 503 forks

Finetune from AF_Multimer parameters #426

Open tpdmskim opened 4 months ago

tpdmskim commented 4 months ago

Hi, I am encountering an issue similar to #421: I attempted to fine-tune a model from the AF-Multimer pretrained parameters (resume_from_jax_params: params_model_1_multimer_v3.npz), but the loss is unusually high, and it looks as though the model is training from scratch rather than using the pretrained parameters.

Here are the arguments I used:

python ~~~/train_openfold.py \
    ~~~/train/alignment_dir \
    ~~~/data/template_mmcif_dir/ \
    ~~~/save/ \
    2021-09-30 \
    --train_mmcif_data_cache_path ~~~/train_mmcif_test.json \
    --precision bf16 \
    --val_data_dir ~~~/valid/data_dir \
    --val_alignment_dir ~~~/valid/alignment_dir \
    --val_mmcif_data_cache_path ~~~/valid_mmcif_test.json \
    --kalign_binary_path ~~~/bin/kalign \
    --obsolete_pdbs_file_path ~~~/pdb_mmcif/obsolete.dat \
    --template_release_dates_cache_path ~~~/template_mmcif_cache.json \
    --seed 622 \
    --replace_sampler_ddp True \
    --checkpoint_every_epoch \
    --resume_from_jax_params ~~~/openfold/resources/params/params_model_1_multimer_v3.npz \
    --log_performance True \
    --script_modules False \
    --train_epoch_len 200 \
    --log_lr \
    --config_preset "model_1_multimer_v3" \
    --gpus 1 \
    --num_processes 16 \
    --strategy ddp \
    --num_nodes=1 \
    --deepspeed_config ~~~/deepspeed_config.json

I suspect that in trainer.py, the training might be initializing from scratch due to the following lines:

trainer.fit(
    model_module,
    datamodule=data_module,
    ckpt_path=ckpt_path,  # ckpt_path is None since I use args.resume_from_jax_params
)

Could this be why the model seems to be training from scratch rather than fine-tuning? If so, how can I convert the .npz file (params_model_1_multimer_v3.npz) to .ckpt format? Is there a script available for this conversion? Thank you for your assistance.
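
One possible route for such a conversion, for reference (an untested sketch, not an official script, assuming the usual module layout of openfold.config, openfold.model.model, and openfold.utils.import_weights): import the JAX weights into a bare AlphaFold model and save its plain state dict with torch.save. The resume logic quoted later in this thread accepts such a file through its final else branch, which prefixes the keys with 'model.' before calling import_openfold_weights_. Note this only covers the model weights; whether the ema is then initialized correctly is a separate question (see the later comments).

# Untested sketch: turn params_model_1_multimer_v3.npz into a file that
# --resume_from_ckpt together with --resume_model_weights_only can read.
# The result is a raw state dict, not a full Lightning checkpoint.
import torch

from openfold.config import model_config
from openfold.model.model import AlphaFold
from openfold.utils.import_weights import import_jax_weights_

config = model_config("model_1_multimer_v3", train=True)
model = AlphaFold(config)

# Copy the AlphaFold-Multimer JAX weights into the PyTorch model.
import_jax_weights_(
    model,
    "params_model_1_multimer_v3.npz",
    version="model_1_multimer_v3",
)

# A bare state dict (no 'module' or 'state_dict' key) is treated as a
# pre-trained model by the resume code in train_openfold.py.
torch.save(model.state_dict(), "params_model_1_multimer_v3.pt")
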
C-de-Furina commented 2 months ago

Hi, have you solved this problem? I printed the weights and biases from the original AF parameters and from a ckpt saved after training for only one step. Almost all of the weights changed sharply, and the biases even look broken. I don't know why the biases are on the order of 1e-6 or 1e-7, but they are definitely wrong.

jasonkim8652 commented 2 months ago

@C-de-Furina Hi, I am suffering from the same problem. I'm curious how you made a ckpt-format file from the AlphaFold2 JAX parameters; there is no script in this repository that converts JAX parameters into an OpenFold ckpt. I'd also like to know how you ran the fine-tuning. Was it the same as in the last issue you opened? Thank you for your help!

C-de-Furina commented 2 months ago

I'm still trying to find a way to create a ckpt. Right now I'm trying to call trainer.save_checkpoint("example.ckpt") without any training, but I've run into some trouble. Before this, I created a ckpt by training from the JAX parameters for only one step, so I can say these parameters are broken from the very beginning. My training is set up the same way as for the monomer, except for --config_preset="model_5_multimer_v3".

Besides, if you use your trained ckpt to predict a homomeric protein such as 8w7d, you will see that all chains almost completely overlap, whereas this does not happen for heteromers. So I think there may be some intrinsic defect in OF-multimer training.

C-de-Furina commented 2 months ago

Well, I understand what happens now. The main problem is that when you load from the JAX parameters, the ExponentialMovingAverage (ema) is initialized incorrectly. You can see in train_openfold.py, around line 307:

if args.resume_from_ckpt:
    if args.resume_model_weights_only:
        # Load the checkpoint
        if os.path.isdir(args.resume_from_ckpt):
            sd = zero_to_fp32.get_fp32_state_dict_from_zero_checkpoint(
                args.resume_from_ckpt)
        else:
            sd = torch.load(args.resume_from_ckpt)
        # Process the state dict
        if 'module' in sd:
            sd = {k[len('module.'):]: v for k, v in sd['module'].items()}
            import_openfold_weights_(model=model_module, state_dict=sd)
        elif 'state_dict' in sd:
            import_openfold_weights_(
                model=model_module, state_dict=sd['state_dict'])
        else:
            # Loading from pre-trained model
            sd = {'model.'+k: v for k, v in sd.items()}
            import_openfold_weights_(model=model_module, state_dict=sd)
        logging.info("Successfully loaded model weights...")

Look: this path loads a state dict that also supplies the parameters the ema uses. But if you check load_from_jax, around line 262:

def load_from_jax(self, jax_path):
    model_basename = os.path.splitext(
        os.path.basename(
            os.path.normpath(jax_path)
        )
    )[0]
    model_version = "_".join(model_basename.split("_")[1:])
    logging.warning('load from version' + model_version)
    import_jax_weights_(
        self.model, jax_path, version=model_version
    )

There is nothing here that touches the ema, so the ema keeps the original parameters from the AF model initialization, which are untrained.

There are probably better ways to do this, but I'm not a professional programmer, so I fixed it as follows:

1. Add an attribute to the ema:

def __init__(self, model: nn.Module, decay: float):
    """
    Args:
        model:
            A torch.nn.Module whose parameters are to be tracked
        decay:
            A value (usually close to 1.) by which updates are
            weighted as part of the above formula
    """
    super(ExponentialMovingAverage, self).__init__()
    clone_param = lambda t: t.clone().detach()
    self.params = tensor_tree_map(clone_param, model.state_dict())
    self.decay = decay
    self.device = next(model.parameters()).device
    self.load_from_jax = False  # new flag: set to True after importing JAX weights

2. Add a method to the ema:

def repair_params_when_load_from_jax(self, state_dict: OrderedDict) -> None:
    # Overwrite the ema copy with the freshly imported weights.
    for k in state_dict.keys():
        self.params[k] = state_dict[k].clone()

3. In update, run this method once when load_from_jax is true, then set the flag back to false:

def update(self, model: torch.nn.Module) -> None:
    """
    Updates the stored parameters using the state dict of the provided
    module. The module should have the same structure as that used to
    initialize the ExponentialMovingAverage object.
    """
    if self.load_from_jax:
        self.load_from_jax = False
        self.repair_params_when_load_from_jax(model.state_dict())
    self._update_state_dict_(model.state_dict(), self.params)

4. When you load from the JAX parameters, set load_from_jax to true:

def load_from_jax(self, jax_path):
    model_basename = os.path.splitext(
        os.path.basename(
            os.path.normpath(jax_path)
        )
    )[0]
    model_version = "_".join(model_basename.split("_")[1:])
    logging.warning('load from version' + model_version)
    import_jax_weights_(
        self.model, jax_path, version=model_version
    )
    self.ema.load_from_jax = True
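
If you want to see the mismatch this patch addresses, here is a small sanity check (a rough, untested sketch, not code from the repo). It assumes the Lightning wrapper exposes .model and .ema as in the snippets above: right after load_from_jax, the model holds the imported AlphaFold weights while ema.params still holds the freshly initialized clones, so the difference is large.

import torch

def max_ema_model_diff(model_module) -> float:
    # Largest absolute difference between the ema copy and the live model weights.
    # Assumes model_module is the training wrapper with .model and .ema, and that
    # ema.params is keyed by the same names as model.state_dict(), as in __init__ above.
    model_sd = model_module.model.state_dict()
    return max(
        (ema_t.detach().float().cpu() - model_sd[name].detach().float().cpu())
        .abs().max().item()
        for name, ema_t in model_module.ema.params.items()
    )

If I understand the flow correctly, this should return a large value right after load_from_jax with the unpatched code, and roughly zero immediately after the first ema.update() call once the repair above has run.
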
jasonkim8652 commented 2 months ago

Thank you for your kind explanation! I will try your solution and check whether it solves the problem.

abhinavb22 commented 1 month ago

I probably had this issue as well while training with a very small batch; the trained params would give me a structure where everything sits at the origin. Your solution works for me, and now I get well-folded proteins. However, the loss is still really high. Is it simply because my batch size is very small (5 proteins), or is something else wrong? I am loading model_1_multimer_v3 and training with 5 proteins whose inference is very accurate, so I would expect the loss to be low (does that make sense?). Were you able to train any multimer models yet?

C-de-Furina commented 1 month ago

That's normal. The initial loss is always high and decreases quickly after a few hundred steps; generally you can get the loss below 80.