DigiRL-agent / digirl

Official repo for paper DigiRL: Training In-The-Wild Device-Control Agents with Autonomous Reinforcement Learning.

Question about loss calculation #22

Closed by langfengQ 1 month ago

langfengQ commented 1 month ago

Hi authors, thank you for your great work. I have a question about the implementation of the policy-gradient loss in your code, referenced here: https://github.com/DigiRL-agent/digirl/blob/3a6ef0d2adb9a312052684a0960976d0011a5e99/digirl/models/autoui_agent.py#L114C1-L119C81

        prediction_probs = self.softmax(outputs.logits)
        selected_prediction_probs = torch.take_along_dim(prediction_probs,\
                                                 action_ids["input_ids"].unsqueeze(2), dim=2).squeeze(2)
        selected_prediction_probs = torch.clamp(selected_prediction_probs, min=0.001, max=0.99)
        # import IPython; IPython.embed(); exit()
        return torch.log(selected_prediction_probs)*action_ids["attention_mask"]

My question concerns the line: torch.clamp(selected_prediction_probs, min=0.001, max=0.99)

Based on my observations during the online phase, most of the values in selected_prediction_probs (>99%) are clipped by this operation, so their gradients are lost. This means that only the few remaining values contribute to the model's update. Could you share the reasoning behind this setting and why so much data is effectively rejected?
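
To illustrate the concern, here is a minimal, self-contained sketch (not taken from the repo, with made-up toy values) showing that torch.clamp has zero gradient wherever the input falls outside [min, max], so any selected probability above 0.99 or below 0.001 stops contributing to the update:

    import torch

    # Toy probabilities standing in for selected_prediction_probs
    probs = torch.tensor([0.9995, 0.5000, 0.0002], requires_grad=True)
    clamped = torch.clamp(probs, min=0.001, max=0.99)
    torch.log(clamped).sum().backward()
    # Only the middle (unclamped) entry receives a gradient of 1/0.5 = 2
    print(probs.grad)  # tensor([0., 2., 0.])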

YifeiZhou02 commented 1 month ago

Thanks for spotting this! The clamp is not a deliberate design choice beyond numerical stability. I would be curious to see whether you can get better performance by improving it!
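
For reference, one common way to keep numerical stability without clamping probabilities is to work in log-space with log_softmax. The sketch below is only an assumed rewrite of the snippet quoted above, not the authors' implementation; the function name and arguments are hypothetical:

    import torch
    import torch.nn.functional as F

    def selected_log_probs(logits, action_ids, attention_mask):
        # log_softmax is numerically stable on its own, so no probability clamp
        # is needed and every selected token keeps a non-zero gradient.
        log_probs = F.log_softmax(logits, dim=-1)  # (batch, seq, vocab)
        selected = torch.take_along_dim(log_probs, action_ids.unsqueeze(2), dim=2).squeeze(2)
        return selected * attention_mask  # mask out padding tokens

Whether this actually improves performance in the online phase would need to be checked empirically, as the reply above suggests.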