kvablack / ddpo-pytorch

DDPO for finetuning diffusion models, implemented in PyTorch with LoRA support
MIT License

Batch size unrecognized #20

Open mao-code opened 7 months ago

mao-code commented 7 months ago

I am using the DDPO logic to fine-tune my own model. However, I found that the example reward function (LLaVA BERTScore) uses a fixed batch size.

After reading the source code in this repo and the TRL DDPOTrainer class, I think this batch size may be related to sample_batch_size.

I recommend either replacing the hard-coded batch size with the one from the config or adding a comment that explains it. That way, people who want to design their own reward function will have a clearer guide.

Below is the example reward function from this repo that I mentioned above.

import numpy as np
import torch

def llava_bertscore():
    """Submits images to LLaVA and computes a reward by comparing the responses to the prompts using BERTScore. See
    https://github.com/kvablack/LLaVA-server for server-side code.
    """
    import requests
    from requests.adapters import HTTPAdapter, Retry
    from io import BytesIO
    import pickle

    batch_size = 16 
    url = "http://127.0.0.1:8085"
    sess = requests.Session()
    retries = Retry(
        total=1000, backoff_factor=1, status_forcelist=[500], allowed_methods=False
    )
    sess.mount("http://", HTTPAdapter(max_retries=retries))

    def _fn(images, prompts, metadata):
        del metadata
        if isinstance(images, torch.Tensor):
            images = (images * 255).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
            images = images.transpose(0, 2, 3, 1)  # NCHW -> NHWC

        images_batched = np.array_split(images, np.ceil(len(images) / batch_size))
        prompts_batched = np.array_split(prompts, np.ceil(len(prompts) / batch_size))
...
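To illustrate the suggestion, here is a minimal sketch of a reward factory that takes the batch size as an argument instead of hard-coding it. It is not code from this repo: `make_chunked_reward`, `score_batch`, and `mean_brightness` are hypothetical names, and the placeholder scorer stands in for the request to the LLaVA server.

```python
import numpy as np


def make_chunked_reward(score_batch, batch_size=16):
    """Build a reward function that scores images in chunks of at most `batch_size`.

    `score_batch` is a hypothetical callable (image_batch, prompt_batch) -> sequence
    of floats, standing in for the request to the LLaVA server.
    """

    def _fn(images, prompts, metadata):
        del metadata
        # Same splitting scheme as the example reward, but with a configurable size.
        n_chunks = max(1, int(np.ceil(len(images) / batch_size)))
        images_batched = np.array_split(images, n_chunks)
        prompts_batched = np.array_split(prompts, n_chunks)

        all_scores = []
        for image_batch, prompt_batch in zip(images_batched, prompts_batched):
            scores = score_batch(image_batch, prompt_batch)
            all_scores.append(np.asarray(scores, dtype=np.float32))

        # One reward per input image, regardless of how the chunks were formed.
        return np.concatenate(all_scores), {}

    return _fn


def mean_brightness(image_batch, prompt_batch):
    # Placeholder scorer for illustration only: mean pixel value per image.
    return image_batch.reshape(len(image_batch), -1).mean(axis=1)


images = np.random.default_rng(0).integers(0, 256, size=(10, 8, 8, 3), dtype=np.uint8)
prompts = [f"prompt {i}" for i in range(10)]

rewards_small, _ = make_chunked_reward(mean_brightness, batch_size=4)(images, prompts, {})
rewards_large, _ = make_chunked_reward(mean_brightness, batch_size=16)(images, prompts, {})
```

With this shape, the value could come from a config field or be documented at the call site. The two calls return the same ten rewards, since the chunk size only changes how the work is grouped.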

And this is the code that calls compute_rewards() in the DDPOTrainer class in the TRL repo:

def step(self, epoch: int, global_step: int):
        """
        Perform a single step of training.

        Args:
            epoch (int): The current epoch.
            global_step (int): The current global step.

        Side Effects:
            - Model weights are updated
            - Logs the statistics to the accelerator trackers.
            - If `self.image_samples_callback` is not None, it will be called with the prompt_image_pairs, global_step, and the accelerator tracker.

        Returns:
            global_step (int): The updated global step.

        """
        samples, prompt_image_data = self._generate_samples(
            iterations=self.config.sample_num_batches_per_epoch,
            batch_size=self.config.sample_batch_size,
        )

        # collate samples into dict where each entry has shape (num_batches_per_epoch * sample.batch_size, ...)
        samples = {k: torch.cat([s[k] for s in samples]) for k in samples[0].keys()}
        rewards, rewards_metadata = self.compute_rewards(
            prompt_image_data, is_async=self.config.async_reward_computation
        )
...
mao-code commented 7 months ago

Or is it just the batch size used when sending images to LLaVA to generate responses? I ask because the rewards from all batches are collected into a single list at the end.
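To make the question concrete, here is a small sketch of the split arithmetic, assuming the reward function only uses batch_size in the np.array_split pattern quoted above. The `chunk_sizes` helper is hypothetical and exists only for this illustration.

```python
import numpy as np


def chunk_sizes(n_items, batch_size):
    """Sizes of the chunks produced by the np.array_split pattern in the example reward."""
    n_chunks = max(1, int(np.ceil(n_items / batch_size)))
    return [len(chunk) for chunk in np.array_split(np.arange(n_items), n_chunks)]


# Fewer items than batch_size: everything goes out in a single chunk.
print(chunk_sizes(8, 16))   # [8]
# More items than batch_size: several near-equal chunks, none larger than batch_size.
print(chunk_sizes(40, 16))  # [14, 13, 13]
```

In both cases the chunk sizes add up to the number of inputs. If that reading is right, batch_size would only cap how many images go into each request, while the number of images per call is decided by whoever calls the reward function.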