hakuhodo-technologies / scope-rl

SCOPE-RL: A python library for offline reinforcement learning, off-policy evaluation, and selection
https://scope-rl.readthedocs.io/en/latest/
Apache License 2.0

Query on Handling Offline Data with scope-rl #25

Open jupitersh opened 5 months ago

jupitersh commented 5 months ago

I work on reinforcement learning research, particularly for medical applications.

My question is about building a logged_dataset in scope-rl from pre-collected offline data (state, action, next state, and reward). The documentation mostly focuses on simulated data, where the behavior policy is known. My data were collected without recording action probabilities, so I'm unsure how to define pscore in the logged_dataset.

Could you provide guidance or share best practices on how to manage pscore for offline datasets within scope-rl? Your input would be highly valuable and greatly assist my research in the medical domain.
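When actions are discrete and states can be grouped into a manageable number of bins, one common way to obtain a pscore is to estimate the behavior policy from the logged data and record the estimated probability of each logged action. The sketch below is a minimal, hypothetical illustration of that idea: `estimate_tabular_pscore` and its Laplace smoothing are assumptions for illustration, not part of scope-rl, and states are assumed to be already mapped to integer ids.

```python
import numpy as np


def estimate_tabular_pscore(state_ids, actions, n_states, n_actions, alpha=1.0):
    """Estimate pscore = pi_b(a_t | s_t) by counting actions per discretized state.

    Hypothetical helper (not part of scope-rl). Assumes state ids are integers in
    [0, n_states) and actions are integers in [0, n_actions). `alpha` is additive
    (Laplace) smoothing so that unseen actions still get a nonzero probability.
    """
    state_ids = np.asarray(state_ids)
    actions = np.asarray(actions)
    counts = np.full((n_states, n_actions), alpha, dtype=float)
    # Accumulate visit counts for each (state, action) pair.
    np.add.at(counts, (state_ids, actions), 1.0)
    # Normalize each row into a conditional distribution pi_b(. | s).
    policy = counts / counts.sum(axis=1, keepdims=True)
    # The pscore of each logged step is the probability of the action actually taken.
    return policy[state_ids, actions], policy


# Toy example: 2 discretized states, 2 actions.
state_ids = np.array([0, 0, 0, 1, 1])
actions = np.array([0, 0, 1, 1, 1])
pscore, policy = estimate_tabular_pscore(state_ids, actions, n_states=2, n_actions=2)
print(pscore)  # values 0.6, 0.6, 0.4, 0.75, 0.75
```

With continuous medical state features, a classifier (for example, multinomial logistic regression) would usually replace the count table. Estimated pscores can be noisy, and small values inflate the variance of importance sampling.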

pmoran3 commented 5 months ago

I am also struggling to set pscore for real-world data. More tutorials and/or documentation for these cases would be greatly appreciated!

aiueola commented 5 months ago

Hi @jupitersh and @pmoran3,

Thank you for the question.

When "pscore" is not recorded in the logged data, I recommend using marginal OPE estimators (e.g., Uehara et al., 20). These estimators first estimate the marginal importance weight of each state-action pair, i.e., the ratio between the state-action visitation distributions of the evaluation and behavior policies, and then apply importance sampling with this estimated weight instead of the per-step pscore.
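To make the idea concrete, here is a minimal NumPy sketch of one common finite-horizon form of the state-action marginal importance sampling estimate, J_hat = (1/n) Σ_i Σ_t γ^t w(s_t, a_t) r_t, where w(s, a) approximates d^π(s, a) / d^π_b(s, a). It assumes the weights have already been estimated and uses hypothetical names; it is not scope-rl's estimator code, which may add normalization or other refinements.

```python
import numpy as np


def marginal_is_estimate(reward, weight, step_per_trajectory, gamma=1.0):
    """State-action marginal importance sampling estimate of a policy value.

    reward, weight: flat arrays of length n_trajectories * step_per_trajectory,
    ordered trajectory by trajectory. `weight` holds estimated marginal weights
    w(s_t, a_t) ~ d^pi(s_t, a_t) / d^pi_b(s_t, a_t); no per-step pscore is used.
    """
    reward_2d = np.asarray(reward, dtype=float).reshape((-1, step_per_trajectory))
    weight_2d = np.asarray(weight, dtype=float).reshape((-1, step_per_trajectory))
    discount = gamma ** np.arange(step_per_trajectory)
    # Weighted, discounted return of each trajectory, then average over trajectories.
    per_trajectory = (discount * weight_2d * reward_2d).sum(axis=1)
    return per_trajectory.mean()


reward = np.array([1.0, 0.0, 1.0, 1.0])  # 2 trajectories x 2 steps
weight = np.array([2.0, 1.0, 0.5, 1.0])
print(marginal_is_estimate(reward, weight, step_per_trajectory=2, gamma=0.9))
```

With all weights equal to 1, the estimate reduces to the average discounted return of the logged data, i.e., the on-policy value of the behavior policy.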

The marginal importance weight is estimated in the CreateOPEInput class when you call "obtain_whole_inputs" with "pscore" in the "logged_dataset" set to None. Please set "require_weight_prediction=True" and choose the weight-estimation method with the "w_function_method" argument.

For general instructions and formatting requirements when using real-world data, please also refer to the documentation on handling real-world data.
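Because formatting mistakes are easy to make with real-world data, a quick consistency check before calling obtain_whole_inputs can help. The hypothetical helper below checks only properties implied by this thread: step-wise arrays should have `size` rows, and `size` should equal `n_trajectories * step_per_trajectory`, since the library reshapes arrays to `(-1, step_per_trajectory)` (as the traceback later in this thread shows). The key list is copied from the dict posted in this thread and is an assumption, not scope-rl's full specification.

```python
import numpy as np

# Keys copied from the logged_dataset dict posted in this thread; treating them
# as the required step-wise fields is an assumption, not scope-rl's spec.
REQUIRED_STEPWISE_KEYS = ("state", "action", "reward", "done", "terminal")


def check_logged_dataset(logged_dataset):
    """Return a list of shape problems found in a logged_dataset dict (hypothetical helper)."""
    problems = []
    size = logged_dataset["size"]
    expected = logged_dataset["n_trajectories"] * logged_dataset["step_per_trajectory"]
    if size != expected:
        problems.append(f"size={size}, but n_trajectories * step_per_trajectory={expected}")
    for key in REQUIRED_STEPWISE_KEYS:
        value = logged_dataset.get(key)
        if value is None:
            problems.append(f"{key!r} is missing")
        elif len(value) != size:
            problems.append(f"{key!r} has {len(value)} rows, expected {size}")
    # pscore may be None when it was never recorded; if given, it must match size too.
    pscore = logged_dataset.get("pscore")
    if pscore is not None and len(pscore) != size:
        problems.append(f"'pscore' has {len(pscore)} rows, expected {size}")
    return problems


dataset = {
    "size": 6, "n_trajectories": 2, "step_per_trajectory": 3,
    "state": np.zeros((6, 3)), "action": np.zeros((6, 2)), "reward": np.zeros(5),
    "done": np.zeros(6), "terminal": np.zeros(6), "pscore": None,
}
print(check_logged_dataset(dataset))  # flags the 5-row reward array
```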

I hope this information will be helpful to you.

pmoran3 commented 5 months ago

@aiueola Thank you for this information. Does this also apply to real-world data with continuous actions?
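For continuous actions, a pscore would be a probability density rather than a probability. One hypothetical way to obtain it is to fit a simple parametric behavior policy, here a linear-Gaussian model a ~ N(W s + b, Σ) fitted by least squares, and evaluate its density at the logged actions. The model and the helper name `gaussian_behavior_pscore` are assumptions for illustration only; real clinical behavior policies are rarely this simple, and density-based importance weights can have very high variance, which is one reason marginal estimators are suggested above.

```python
import numpy as np


def gaussian_behavior_pscore(states, actions, ridge=1e-6):
    """Fit a ~ N(W s + b, Sigma) by least squares; return the density of each logged action."""
    states = np.asarray(states, dtype=float)
    actions = np.asarray(actions, dtype=float)
    n, action_dim = actions.shape
    # Append a bias column and regress actions on states.
    X = np.hstack([states, np.ones((n, 1))])
    coef, *_ = np.linalg.lstsq(X, actions, rcond=None)
    residual = actions - X @ coef
    # Residual covariance, with a small ridge term to keep it invertible.
    sigma = residual.T @ residual / n + ridge * np.eye(action_dim)
    inv = np.linalg.inv(sigma)
    _, logdet = np.linalg.slogdet(sigma)
    # Multivariate normal log-density evaluated at each logged action.
    mahalanobis = np.einsum("ij,jk,ik->i", residual, inv, residual)
    log_density = -0.5 * (mahalanobis + logdet + action_dim * np.log(2 * np.pi))
    return np.exp(log_density)


rng = np.random.default_rng(0)
states = rng.normal(size=(500, 3))                            # state_dim = 3, as in the thread
true_W = rng.normal(size=(3, 2))
actions = states @ true_W + 0.1 * rng.normal(size=(500, 2))  # action_dim = 2
pscore = gaussian_behavior_pscore(states, actions)
print(pscore.shape)  # (500,)
```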

When I set "pscore" to None and call "obtain_whole_inputs," I get an error:

/usr/local/lib/python3.10/dist-packages/scope_rl/ope/input.py in _register_logged_dataset(self, logged_dataset)
    388         )
    389         self.reward_2d = self.reward.reshape((-1, self.step_per_trajectory))
--> 390         self.pscore_2d = self.pscore.reshape((-1, self.step_per_trajectory))
    391         self.done_2d = self.done.reshape((-1, self.step_per_trajectory))
    392         self.terminal_2d = self.terminal.reshape((-1, self.step_per_trajectory))

AttributeError: 'NoneType' object has no attribute 'reshape'
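The traceback shows that, in the installed version, `_register_logged_dataset` reshapes `pscore` unconditionally, so `None` fails before any weight estimation runs. The snippet below reproduces that reshape in plain NumPy and shows a hypothetical stopgap: a placeholder array of ones with the right length. This is an assumption, not a documented scope-rl workaround and untested against the library; it only avoids the crash. Any estimator that actually reads pscore (e.g., standard per-decision importance sampling) would then return meaningless values, so only marginal-estimator results should be considered.

```python
import numpy as np

size, step_per_trajectory = 100_000, 10  # values from the dict posted in this thread

pscore = None
try:
    # Same operation as input.py line 390 in the traceback.
    pscore.reshape((-1, step_per_trajectory))
except AttributeError as err:
    print("fails as in the traceback:", err)

# Hypothetical stopgap (untested with scope-rl): a placeholder array of ones.
# It has the right length, but it is NOT a real propensity.
placeholder_pscore = np.ones(size)
pscore_2d = placeholder_pscore.reshape((-1, step_per_trajectory))
print(pscore_2d.shape)  # (10000, 10)
```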

This is how I am initializing the logged_dataset dict and calling obtain_whole_inputs:

from scope_rl.ope import CreateOPEInput

# actions_test, observations_test, rewards_test, terminals, evaluation_policies,
# and random_state are defined elsewhere; step-wise arrays have size = 100000 rows.
test_logged_dataset = {
  "size":100000,
  "step_per_trajectory":10,
  "n_trajectories":10000,
  "action":actions_test,
  "state":observations_test,
  "reward":rewards_test,
  "action_type":"continuous",
  "n_actions":None,
  "action_meaning":None,
  "state_dim":3,
  "done":terminals,
  "terminal":terminals,
  "random_state":random_state,
  "action_dim":2,
  "behavior_policy":None,
  "dataset_id":0,
  "pscore":None,  # not recorded in the real-world data
  "info":None
}
prep = CreateOPEInput()

# Estimate marginal importance weights with DICE, since pscore is unavailable.
input_dict = prep.obtain_whole_inputs(
    logged_dataset=test_logged_dataset,
    evaluation_policies=evaluation_policies,
    require_weight_prediction=True,
    require_value_prediction=True,
    w_function_method="dice",
    n_trajectories_on_policy_evaluation=100,
    random_state=random_state,
)
ericyue commented 5 months ago

@aiueola Could you provide a more detailed Jupyter notebook showing how to load custom logged data (without pscore) and train a BCQ (or other) model? It would be very helpful!

pmoran3 commented 5 months ago

@jupitersh Were you able to calculate the pscores properly for your problem? I am still having issues.

ericyue commented 4 months ago

@jupitersh Have you solved this problem? Any ideas would be helpful.