PolymathicAI / multiple_physics_pretraining

Code for paper "Multiple Physics Pretraining for Physical Surrogate Models
MIT License

Instructions on using the pretrained model #3

Open Bucanero06 opened 5 months ago

Bucanero06 commented 5 months ago

Hi, I have prepared the Python environment and downloaded the data from PDEBench, as well as the model weights from Google Drive. I can train the model from scratch on the data I currently have available, and I wanted to ask for your instructions on how to load the pre-trained model. Thank you for your time and contribution.

P.S. I also have some questions about the MPP 2023 paper; what is your preferred method of communication about the project?

Bucanero06 commented 5 months ago

I have added the path of the MPP_AViT_S tar file to the pretrained_ckpt_path setting in the YAML config. From there I get an AttributeError when accessing self.model.module:

python train_basic.py --run_name first_test_run --config finetune --yaml_config config/mpp_avit_s_config.yaml 
Loading configuration file: multiple_physics_pretraining/config/mpp_avit_s_config.yaml
Configuration name: finetune
Initializing data on rank 0
Initializing model on rank 0
Model parameter count: 28979436
Starting from pretrained model at weights/MPP_AViT_S

Traceback (most recent call last):
  File "multiple_physics_pretraining/train_basic.py", line 547, in <module>
    trainer = Trainer(params, global_rank, local_rank, device, sweep_id=args.sweep_id)
  File "multiple_physics_pretraining/train_basic.py", line 80, in __init__
    self.restore_checkpoint(params.pretrained_ckpt_path)
  File "multiple_physics_pretraining/train_basic.py", line 201, in restore_checkpoint
    self.model.module.unfreeze()
  File "multiple_physics_pretraining/multiple_physics_pretrained_env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1695, in __getattr__
    raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
AttributeError: 'AViT' object has no attribute 'module'. Did you mean: 'modules'?

I will look into it, so I'm just updating the thread here.
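
For context, the module attribute in that traceback is only added by the DistributedDataParallel wrapper; a bare nn.Module does not have it. A minimal sketch of the guard (the unwrap helper below is just illustrative, not something from the repo):

import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

def unwrap(model: nn.Module) -> nn.Module:
    # DDP exposes the wrapped network via .module; a plain nn.Module does not,
    # which is exactly the AttributeError shown above.
    return model.module if isinstance(model, DDP) else model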

Bucanero06 commented 5 months ago

The issue I was encountering is consistent with accessing self.model.module when self.model is not wrapped in a DistributedDataParallel (DDP) object; the module attribute is only added by the DDP wrapper. The original method assumed the attribute was present even when DDP is not used. The updated restore_checkpoint below checks whether self.model is an instance of DDP and only then accesses module, both when loading the state dict and inside the self.params.pretrained branch. These are just minor changes, but it is working on my single-GPU workstation.

    # Note: requires OrderedDict (from collections) and DDP
    # (torch.nn.parallel.DistributedDataParallel) to be in scope.
    def restore_checkpoint(self, checkpoint_path):
        """Load model/optimizer state from checkpoint_path."""
        checkpoint = torch.load(checkpoint_path, map_location='cuda:{}'.format(self.local_rank))
        if 'model_state' in checkpoint:
            model_state = checkpoint['model_state']
        else:
            model_state = checkpoint
        try:  # First try to load directly into self.model (which may be the DDP wrapper)
            self.model.load_state_dict(model_state)
        except RuntimeError:  # Key mismatch: handle the 'module.' prefix added/expected by DDP
            if isinstance(self.model, DDP):
                self.model.module.load_state_dict(model_state)
            else:
                new_state_dict = OrderedDict()
                for key, val in model_state.items():
                    # The checkpoint came from a DDP-wrapped model - strip the leading 'module.' (7 chars)
                    name = key[7:]
                    new_state_dict[name] = val
                self.model.load_state_dict(new_state_dict)

        if self.params.resuming:  # restore_checkpoint is used for finetuning as well as resuming; when finetuning (i.e., not resuming), it does not load the optimizer state and instead uses the config-specified lr
            self.iters = checkpoint['iters']
            self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            self.startEpoch = checkpoint['epoch']
            self.epoch = self.startEpoch
        else:
            self.iters = 0

        if self.params.pretrained:
            if isinstance(self.model, DDP):
                model_to_modify = self.model.module
            else:
                model_to_modify = self.model

            if self.params.freeze_middle:
                model_to_modify.freeze_middle()
            elif self.params.freeze_processor:
                model_to_modify.freeze_processor()
            else:
                model_to_modify.unfreeze()

            # See how much we need to expand the projections
            exp_proj = 0
            # Iterate through the appended datasets and add on enough embeddings for all of them.
            for add_on in self.params.append_datasets:
                exp_proj += len(DSET_NAME_TO_OBJECT[add_on]._specifics()[2])
            model_to_modify.expand_projections(exp_proj)

        checkpoint = None
        self.model = self.model.to(self.device)
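
As a rough answer to the original question of using the pretrained weights outside the Trainer (e.g., for inference), the same checkpoint handling can be applied standalone. This is only a sketch: load_pretrained_state is a hypothetical helper mirroring restore_checkpoint above, and building the AViT model itself from the mpp_avit_s config is repo-specific and not shown here.

from collections import OrderedDict

import torch

def load_pretrained_state(checkpoint_path, map_location='cpu'):
    """Return a state dict usable by a bare (non-DDP) model."""
    checkpoint = torch.load(checkpoint_path, map_location=map_location)
    # The checkpoint may store the weights under 'model_state' or be a raw state dict.
    model_state = checkpoint.get('model_state', checkpoint)
    # Strip the 'module.' prefix that DDP adds to parameter names, if present.
    cleaned = OrderedDict()
    for key, val in model_state.items():
        name = key[len('module.'):] if key.startswith('module.') else key
        cleaned[name] = val
    return cleaned

# Usage, assuming `model` is an AViT built with the same config used for pretraining:
# model.load_state_dict(load_pretrained_state('weights/MPP_AViT_S'))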