YifeiZhou02 / ArCHer

Research Code for "ArCHer: Training Language Model Agents via Hierarchical Multi-Turn RL"
https://yifeizhou02.github.io/archer.io/

Issues with loading in `lm_optimizer_state_dict` #9

Open starship006 opened 3 months ago

starship006 commented 3 months ago

I am working on a slightly modified version of this repository, so I am trying to determine whether this error is on my side or not. I am running a distributed multi-GPU setup with Accelerate, and I am having issues loading the lm_optimizer state. Here is my current saving and loading code inside trainer.py:

    def save(self, path):
        torch.save({'model_state_dict': self.accelerator.unwrap_model(self.agent.model).state_dict(),
                    'critic_state_dict': self.accelerator.unwrap_model(self.agent.critic).state_dict(),
                    'target_critic_state_dict': self.accelerator.unwrap_model(self.agent.target_critic).state_dict(),
                    }, path)
        # do it at the same path, but with a different name
        torch.save({'critic_optimizer_state_dict': self.critic_optimizer.state_dict()}, path.replace('.pt', '_critic_optim.pt'))
        torch.save({'lm_optimizer_state_dict': self.lm_optimizer.state_dict()}, path.replace('.pt', '_lm_optim.pt'))

    def load(self, path):
        # We've modified the below to load in via the CPU. This fixes a memory issue.
        # The agent will/should be prepared down the line, and the critic/lm
        # optimizer is re-prepared here.
        checkpoint = torch.load(path, map_location=torch.device('cpu'))
        self.agent.model.load_state_dict(checkpoint['model_state_dict'])
        self.agent.critic.load_state_dict(checkpoint['critic_state_dict'])
        self.agent.target_critic.load_state_dict(checkpoint['target_critic_state_dict'])

        critic_optim_checkpoint = torch.load(path.replace('.pt', '_critic_optim.pt'))
        self.critic_optimizer.load_state_dict(critic_optim_checkpoint['critic_optimizer_state_dict'])

        # The following crashes
        # trainer_checkpoint = torch.load(path.replace('.pt', '_lm_optim.pt'))
        # self.lm_optimizer.load_state_dict(trainer_checkpoint['lm_optimizer_state_dict'])

        self.critic_optimizer, self.lm_optimizer = self.accelerator.prepare(self.critic_optimizer, self.lm_optimizer)
        return self.agent

The code above works fine, but it does not load the lm_optimizer state. When I uncomment those two lines, everything runs until self.lm_optimizer.step() is called, which fails with:

RuntimeError: Tensors of the same index must be on the same device and the same dtype except `step` tensors that can be CPU and float32/64 notwithstanding

I'm currently pretty lost as to what the bug might be; I don't think I've changed any code relevant to the lm_optimizer. If this is something you recognize, I would very much appreciate any pointers!
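For reference, a rough sketch of the kind of check that might narrow this down (dump_optimizer_state is an ad-hoc helper and trainer stands for whatever instance holds the two optimizers; neither is part of the repo): print the device and dtype of every parameter and its optimizer state tensors right after load(), for both optimizers, and compare.

    import torch

    def dump_optimizer_state(optimizer, name):
        # Print the device/dtype of each parameter and of its optimizer state
        # tensors, to spot a mixed (device, dtype) group before the step crashes.
        for group_idx, group in enumerate(optimizer.param_groups):
            for param in group['params']:
                print(f"{name} group {group_idx}: param {tuple(param.shape)} "
                      f"on {param.device} ({param.dtype})")
                for key, value in optimizer.state.get(param, {}).items():
                    if torch.is_tensor(value):
                        print(f"    state['{key}'] on {value.device} ({value.dtype})")

    # Compare the optimizer that works with the one that crashes:
    dump_optimizer_state(trainer.critic_optimizer, 'critic_optimizer')
    dump_optimizer_state(trainer.lm_optimizer, 'lm_optimizer')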

YifeiZhou02 commented 3 months ago

I'm sorry, I don't think I have run into an issue like this before. Do you think the device might be getting messed up somewhere during loading and checkpointing? Is it that only the lm_optimizer crashes while the critic_optimizer works fine? It sounds weird to me too.
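One thing that might help rule the device question out (just a sketch, not a confirmed fix): load the lm optimizer checkpoint onto the CPU explicitly, the same way the model checkpoint is loaded, and let load_state_dict move the state onto the parameters:

    # Mirror the map_location used for the model checkpoint; load_state_dict
    # should then cast the loaded state onto the parameters' device/dtype.
    lm_optim_checkpoint = torch.load(path.replace('.pt', '_lm_optim.pt'),
                                     map_location=torch.device('cpu'))
    self.lm_optimizer.load_state_dict(lm_optim_checkpoint['lm_optimizer_state_dict'])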

starship006 commented 3 months ago

Yup, critic_optimizer works but lm_optimizer crashes when stepping. It might be worth noting that we are currently using bfloat16. So far I'm mostly unsure about what's going on; I might check whether this replicates on this repo itself.
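If bfloat16 does turn out to be the culprit, one possible workaround (a sketch under that assumption, not a confirmed fix) would be to force every non-step state tensor onto its parameter's device and dtype after loading, before the first lm_optimizer.step():

    # Hypothetical post-load normalization: align the optimizer state with its
    # parameters so the foreach/fused step sees homogeneous (device, dtype)
    # groups. `step` is left untouched, since the error message allows it to
    # stay on CPU in float32/64.
    for param, state in self.lm_optimizer.state.items():
        for key, value in state.items():
            if torch.is_tensor(value) and key != 'step':
                state[key] = value.to(device=param.device, dtype=param.dtype)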