OpenLLMAI / OpenRLHF

An Easy-to-use, Scalable and High-performance RLHF Framework (70B+ PPO Full Tuning & Iterative DPO & LoRA & Mixtral)
https://openrlhf.readthedocs.io/
Apache License 2.0

Generate function for distributed training #324

Open louieworth opened 3 weeks ago

louieworth commented 3 weeks ago

I am trying to generate new samples online with multiple GPUs in parallel, but the process hangs and makes no progress. When I run it on a single GPU, it does not hang. How can I solve this problem?

#323 suggested using vLLM for generation. Could you please give me detailed instructions on how to do this? I use my generate function in the DPO algorithm to test the generation quality of DPO.

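from typing import Tuple, Union
import torch
import torch.nn as nn

class Actor(nn.Module):
    def __init__(self):
        super().__init__()
        ...  # model construction elided in the original post

    def generate(self, input_ids: torch.Tensor, **kwargs) -> Union[
        Tuple[torch.LongTensor, torch.LongTensor],
        Tuple[torch.LongTensor, torch.LongTensor, torch.BoolTensor],
    ]:
        ...  # body elided in the original post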
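A minimal sketch of one way to split online generation across GPUs, assuming plain data parallelism where every GPU holds a full copy of the model and generates for its own slice of the prompts (this is not necessarily how OpenRLHF does it). `shard_for_rank` and `merge_shards` are hypothetical helpers; in a real multi-process job each rank would call its model's `generate` on its shard, and the per-rank outputs would be collected with something like `torch.distributed.all_gather_object` before merging.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def shard_for_rank(items: Sequence[T], rank: int, world_size: int) -> List[T]:
    """Return the strided slice of `items` that `rank` should generate for."""
    if not 0 <= rank < world_size:
        raise ValueError(f"rank {rank} out of range for world_size {world_size}")
    # Item i goes to rank i % world_size, at position i // world_size.
    return list(items[rank::world_size])


def merge_shards(shards: List[List[T]]) -> List[T]:
    """Interleave per-rank outputs back into the original prompt order.

    shards[r] is the list rank r produced (e.g. gathered from all ranks).
    """
    world_size = len(shards)
    total = sum(len(s) for s in shards)
    merged: List[T] = []
    for i in range(total):
        # Invert the strided split done by shard_for_rank.
        merged.append(shards[i % world_size][i // world_size])
    return merged
```

Strided slicing keeps reassembly trivial, because item i always lands on rank `i % world_size` at position `i // world_size`. When the prompt count is not a multiple of the world size, some ranks get one more prompt than others, so ranks can finish at different times.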
hijkzzz commented 3 weeks ago

Just remove model.eval() before generation in DPO.

louieworth commented 3 weeks ago

I have removed model.eval(), but it still does not work. It still gets stuck at:

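from typing import Tuple, Union
import torch
import torch.nn as nn

class Actor(nn.Module):
    def __init__(self):
        super().__init__()
        ...  # model construction elided in the original post

    def generate(self, input_ids: torch.Tensor, **kwargs) -> Union[
        Tuple[torch.LongTensor, torch.LongTensor],
        Tuple[torch.LongTensor, torch.LongTensor, torch.BoolTensor],
    ]:
        ...  # body elided in the original post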
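A hedged note on one common cause of generation that hangs only on multiple GPUs (an assumption, not a confirmed diagnosis of this issue): when the model is sharded across ranks, for example with DeepSpeed ZeRO-3, every forward pass inside the decoding loop performs collective communication, so all ranks must run the same number of decoding steps. If one rank's sequences finish early and it leaves the loop, the remaining ranks wait forever at the next collective. Hugging Face transformers' `generate()` has a `synced_gpus` argument for this case, which keeps finished ranks running forward passes until every rank is done; whether the actor's `generate` forwards its `**kwargs` there is an assumption to check in the OpenRLHF source. The pure-Python simulation below models the effect with threads: `simulate_generation` and its step counts are hypothetical, and a `threading.Barrier` with a timeout stands in for the all-gather.

```python
import threading
from typing import List


def simulate_generation(lengths: List[int], synced: bool, timeout: float = 1.0) -> List[int]:
    """Return the ranks that end up blocked waiting on a collective.

    lengths[r] is how many decoding steps rank r needs before its own
    sequences are finished (hypothetical numbers).
    """
    world_size = len(lengths)
    # Stand-in for the parameter all-gather a sharded forward pass performs:
    # nobody proceeds until every rank reaches it. The timeout turns a real
    # deadlock into a detectable BrokenBarrierError.
    collective = threading.Barrier(world_size, timeout=timeout)
    # In a real job each rank would learn this via an all-reduce of "done" flags.
    steps_for_slowest = max(lengths)
    blocked: List[int] = []
    lock = threading.Lock()

    def rank_loop(rank: int) -> None:
        # Synced: keep stepping until the slowest rank is done.
        # Unsynced: stop as soon as this rank's own sequences are done.
        steps = steps_for_slowest if synced else lengths[rank]
        try:
            for _ in range(steps):
                collective.wait()  # one "forward pass" per decoding step
        except threading.BrokenBarrierError:
            with lock:
                blocked.append(rank)

    threads = [threading.Thread(target=rank_loop, args=(r,)) for r in range(world_size)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return sorted(blocked)
```

With `lengths=[3, 5]` and `synced=False`, rank 0 leaves after three steps and rank 1 is left waiting at its fourth collective; with `synced=True` both ranks run five steps and nobody blocks. On a single GPU there is no second rank to wait for, which is consistent with the hang only appearing in the multi-GPU run.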