flowersteam / lamorel

Lamorel is a Python library designed for RL practitioners eager to use Large Language Models (LLMs).
MIT License

How to load trained Flan-T5 model and then fine tune #17

Closed: yanxue7 closed this issue 8 months ago

yanxue7 commented 10 months ago

Hi, I wonder how to load a trained T5 policy and then fine-tune it on unseen tasks. I found the parameters of llm_module in PPOUpdater, but I can't find how to call load_state_dict() on llm_module. Could you kindly provide a solution to this problem? Thank you in advance.

if kwargs["save_after_update"]:
    print("Saving model...")
    torch.save(self._llm_module.state_dict(), kwargs["output_dir"] + "/model.checkpoint")
    print("Model saved")

yanxue7 commented 10 months ago

I modified server.py (in its __init__()) as:

if custom_model_initializer is not None:
    self._model = custom_model_initializer.initialize_model(self._model)

print('aaaa,config.from_pretrained_ppo:', config.from_pretrained_ppo)
if config.from_pretrained_ppo is True:  ###@@@@yanxue
    print('''@@@@load from pretrained@@@@''')
    checkpoint_path = config.pretrained_ppo_dir + "/model.checkpoint"
    checkpoint = torch.load(checkpoint_path)
    self._model.load_state_dict(checkpoint)

if custom_updater is not None:
    self._updater = custom_updater
    self._updater.set_llm_module(
        DDP(self._model, process_group=self._llm_group,
            find_unused_parameters=config.allow_subgraph_use_whith_gradient),
    )
    assert isinstance(self._updater, BaseUpdater)
else:
    self._updater = BaseUpdater()
    self._updater.set_llm_module(self._model)
self.run()

But I meet the error below:

Set the environment variable HYDRA_FULL_ERROR=1 for a complete stack trace.
Error executing job with overrides: ['rl_script_args.path=/home/yanxue/lamorel/examples/PPO_finetuning/main.py']
Traceback (most recent call last):
  File "/home/yanxue/lamorel/examples/PPO_finetuning/main.py", line 210, in main
    output = lm_server.custom_module_fns(['score', 'value'],
  File "/home/yanxue/Grounding/lamorel/lamorel/src/lamorel/caller.py", line 97, in custom_module_fns
    return self.__call_model(InstructionsEnum.FORWARD, True, module_function_keys=module_function_keys,
  File "/home/yanxue/Grounding/lamorel/lamorel/src/lamorel/caller.py", line 101, in __call_model
    dist.gather_object(
  File "/home/yanxue/anaconda3/envs/dlp/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py", line 1451, in wrapper
    return func(*args, **kwargs)
  File "/home/yanxue/anaconda3/envs/dlp/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py", line 2151, in gather_object
    all_gather(object_size_list, local_size, group=group)
  File "/home/yanxue/anaconda3/envs/dlp/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py", line 1451, in wrapper
    return func(*args, **kwargs)
  File "/home/yanxue/anaconda3/envs/dlp/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py", line 2455, in all_gather
    work.wait()
RuntimeError: [../third_party/gloo/gloo/transport/tcp/pair.cc:598] Connection closed by peer [127.0.1.1]:8600

ClementRomac commented 10 months ago

Hi,

There's no need to modify lamorel to do this.

To load a pretrained LLM, just provide the right path in the model_path key of your config file (see https://github.com/flowersteam/lamorel/tree/main#setting-up-the-configuration); Lamorel will take care of loading the LLM.

To save the newly updated parameters, your code seems right:

if kwargs["save_after_update"]:
    print("Saving model...")
    torch.save(self._llm_module.state_dict(), kwargs["output_dir"] + "/model.checkpoint")
    print("Model saved")

If you want to only load some part of the model, here's an example of how to do it: https://github.com/flowersteam/Grounding_LLMs_with_online_RL/blob/551119d2bee73864cd51567da115ceb05ec35508/experiments/train_language_agent.py#L176-L183
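For illustration, here is a minimal plain-PyTorch sketch of what partial loading can look like; the helper name and the "_LLM_model" key prefix below are hypothetical and only meant to mirror the idea in the linked example, not its exact code:

import torch
import torch.nn as nn

def load_partial_checkpoint(model: nn.Module, checkpoint_path: str, prefix: str = "_LLM_model"):
    # Keep only the entries of the saved state_dict whose keys start with `prefix`
    # (e.g. the LLM weights), leaving everything else (e.g. a freshly initialised
    # value head) untouched.
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    partial_state_dict = {k: v for k, v in checkpoint.items() if k.startswith(prefix)}
    # strict=False tolerates the keys we deliberately left out of the dict.
    missing, unexpected = model.load_state_dict(partial_state_dict, strict=False)
    print(f"Loaded {len(partial_state_dict)} tensors ({len(missing)} missing, {len(unexpected)} unexpected)")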

yanxue7 commented 10 months ago

Hi, I want to load the LLM policy pretrained on BabyAI-Text, and I set model_path to "/home/yanxue/lamoral/pposmall/model.checkpoint". Then the error message is "huggingface_hub.utils._validators.HFValidationError: Repo id must be in the form 'repo_name' or 'namespace/repo_name': '/home/yanxue/lamoral/pposmall/model.checkpoint'. Use repo_type argument if needed."

It seems model_path is only suitable for Hugging Face models? In PPO_finetuning/local_gpu_config.yaml:

llm_args:
    model_type: seq2seq
    model_path: "/home/yanxue/lamoral/pposmall/model.checkpoint"
    #t5-small
    pretrained: true
    minibatch_size: 192
    pre_encode_inputs: true
    parallelism:
      use_gpu: true
      model_parallelism_size: 1
      synchronize_gpus_after_scoring: false
      empty_cuda_cache_after_scoring: false

yanxue7 commented 10 months ago

I seem to have solved this problem. I modified server.py as:

if custom_updater is not None:
    self._updater = custom_updater
    self._updater.set_llm_module(
        DDP(self._model, process_group=self._llm_group,
            find_unused_parameters=config.allow_subgraph_use_whith_gradient),
    )
    assert isinstance(self._updater, BaseUpdater)
else:
    self._updater = BaseUpdater()
    self._updater.set_llm_module(self._model)

print('aaaa,config.from_pretrained_ppo:', config.from_pretrained_ppo)
if config.from_pretrained_ppo is True:  ###@@@@yanxue
    print('''@@@@load from pretrained@@@@''')
    checkpoint_path = config.pretrained_ppo_dir + "/model.checkpoint"
    checkpoint = torch.load(checkpoint_path)
    self._updater._llm_module.load_state_dict(checkpoint)

self.run()

Thank you for your careful response.

ClementRomac commented 10 months ago

Hi,

Well here again you could do the same without modifying lamorel's code. You can load the weights in an Initializer or in an Updater (https://github.com/flowersteam/Grounding_LLMs_with_online_RL/blob/551119d2bee73864cd51567da115ceb05ec35508/experiments/train_language_agent.py#L192C1-L194C85).

The whole point of Updaters and Initializers is to give direct access to the LLM, so there is no need to modify lamorel's source code.
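
As a rough illustration (not taken verbatim from the linked code), loading a previous checkpoint from inside a custom Updater could look like the sketch below; the class name, the loading_path kwarg and the exact perform_update signature are assumptions based on the PPO example, not an official reference:

import torch
from lamorel import BaseUpdater

class CheckpointLoadingUpdater(BaseUpdater):
    # Sketch only: lazily load previously saved weights the first time an
    # update is requested. self._llm_module is whatever was passed to
    # set_llm_module() (the DDP-wrapped model on the server side).
    def perform_update(self, contexts, candidates, _current_batch_ids, **kwargs):
        if not hasattr(self, "_checkpoint_loaded") and kwargs.get("loading_path") is not None:
            checkpoint = torch.load(kwargs["loading_path"] + "/model.checkpoint")
            self._llm_module.load_state_dict(checkpoint)
            self._checkpoint_loaded = True

        # ... the regular PPO loss computation and optimizer step would go here ...
        return {}

The custom updater is then passed to the Caller through its custom_updater argument (which is what the server code shown earlier in this thread receives), so lamorel itself stays untouched.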

yanxue7 commented 10 months ago

Great! Thank you!