h2r / pomdp-py

A framework to build and solve POMDP problems. Documentation: https://h2r.github.io/pomdp-py/
MIT License

Sarsop won't compute policy after changing observation model #28

Closed: Hororohoruru closed this issue 1 year ago

Hororohoruru commented 1 year ago

Hello. I have a model using SARSOP that was working until now. Following the discussion we had in #27, I thought about re-running this model with a small change to the observation model's probability() function. The old function was:

def probability(self, observation, next_state, action):
    # The probability of obtaining a new observation knowing the state is given by the discretization / conf matrix
    obs_idx = int(observation.id)
    state_idx = int(next_state.id)
    return self.observation_matrix[state_idx][obs_idx]

And I changed it to:

def probability(self, observation, next_state, action):
    # The probability of obtaining a new observation knowing the state is given by the discretization / conf matrix
    if 'wait' in action.name:
        obs_idx = int(observation.id)
        state_idx = int(next_state.id)
        return self.observation_matrix[state_idx][obs_idx]
    else:
        return 1 / self.n_obs

Then, when I try to compute the SARSOP policy, I get:

-------------------------------------------------------------------------------
 Time   |#Trial |#Backup |LBound    |UBound    |Precision  |#Alphas |#Beliefs  
-------------------------------------------------------------------------------
 0.25    0       0        -99.9996   452.573    552.573     13       1        
ERROR: min_ratio > 1 in upperBoundInternal!
  (min_ratio-1)=4.00002e-06
  normb=1
  b=size: 12,
 data: [0= 0.0833333, 1= 0.0833333, 2= 0.0833333, 3= 0.0833333, 4= 0.0833333, 5= 0.0833333, 6= 0.0833333, 7= 0.0833333, 8= 0.0833333, 9= 0.0833333, 10= 0.0833333, 11= 0.0833333]
  normc=0.999996
  c=size: 12,
 data: [0= 0.083333, 1= 0.083333, 2= 0.083333, 3= 0.083333, 4= 0.083333, 5= 0.083333, 6= 0.083333, 7= 0.083333, 8= 0.083333, 9= 0.083333, 10= 0.083333, 11= 0.083333]
------------------------------------------

What may be the cause of this?

zkytony commented 1 year ago

I’m not exactly sure. Can you check if the observation probabilities sum to 1 using the new model?

Hororohoruru commented 1 year ago

I just checked; it does add up to 1:

decision = BCIAction(2)
next_state = BCIState(3)

obs_model = vep_problem.agent.observation_model

all_observations = obs_model.get_all_observations()
obs_p = [obs_model.probability(obs, next_state, decision) for obs in all_observations]

print(obs_p)
>[0.08333333333333333,
  0.08333333333333333,
  0.08333333333333333,
  0.08333333333333333,
  0.08333333333333333,
  0.08333333333333333,
  0.08333333333333333,
  0.08333333333333333,
  0.08333333333333333,
  0.08333333333333333,
  0.08333333333333333,
  0.08333333333333333]

print(sum(obs_p))
> 1.0
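The check above covers a single (state, action) pair. Below is a sketch that checks every pair at once. It assumes a hypothetical 12x12 confusion matrix and hypothetical action names that stand in for the real BCI problem. `make_observation_matrix`, `ACTIONS` and `check_normalization` are illustrative helpers, not pomdp_py API.

```python
import random

N = 12  # states and observations, matching the 12-entry beliefs in the log

def make_observation_matrix(n, seed=0):
    """Hypothetical row-stochastic confusion matrix: row = next state, column = observation."""
    rng = random.Random(seed)
    matrix = []
    for _ in range(n):
        row = [rng.random() for _ in range(n)]
        total = sum(row)
        matrix.append([x / total for x in row])  # normalize each row to sum to 1
    return matrix

OBS_MATRIX = make_observation_matrix(N)
ACTIONS = ["wait"] + [f"decide_{i}" for i in range(N)]  # hypothetical action names

def probability(obs_idx, state_idx, action_name):
    """Same branching as the modified probability(): matrix for 'wait', uniform otherwise."""
    if "wait" in action_name:
        return OBS_MATRIX[state_idx][obs_idx]
    return 1 / N

def check_normalization(tol=1e-9):
    """Return every (action, state, sum) whose observation distribution does not sum to 1."""
    bad = []
    for action in ACTIONS:
        for s in range(N):
            total = sum(probability(o, s, action) for o in range(N))
            if abs(total - 1.0) > tol:
                bad.append((action, s, total))
    return bad

print(check_normalization())  # an empty list means every distribution sums to 1
```

A model like this one passes the check. This is consistent with the thread: the distributions are normalized in Python, so the problem has to appear later, when the numbers are written out for the solver.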

zkytony commented 1 year ago

In that case, I’m not really sure. I’ve also run into cases where a small change caused the sarsop solver to fail.

I personally don’t understand the error message. If I were debugging this, I would start by looking at what min_ratio means.

Hororohoruru commented 1 year ago

I don't know enough C++ to check the error myself, so I am stuck...

zkytony commented 1 year ago

It is possible that this is a numerical error.

Looking at the sarsop source code where this error is produced, it's OK if min_ratio is slightly bigger than 1 (the tolerance is 1e-10). In the output you posted, min_ratio exceeds 1 by 4.00002e-06. To me that is also pretty small, but it is well above the tolerance. It could be that a numerical precision issue happened on the Python side. That is in fact quite likely, because all your probabilities are 0.0833333... (1/12), and the c vector in your log shows exactly 0.083333. When the .pomdpx file gets generated, the text version of each float may drop some decimals, and that would cause the mismatch.
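You can test the truncation hypothesis with plain arithmetic. The sketch below writes 1/12 with `%f`, the way a text-file writer would, and parses it back. It then sums the twelve truncated entries and divides the exact entry by the written one. Treating that ratio as what SARSOP reports as `min_ratio` is an assumption, but the numbers line up with `normc=0.999996` and `(min_ratio-1)=4.00002e-06` in the log.

```python
N_OBS = 12
exact = 1 / N_OBS            # 0.0833333... as a Python float

written = "%f" % exact       # %f keeps six digits after the decimal point
parsed = float(written)      # what a reader of the text file gets back

row_sum = sum([parsed] * N_OBS)  # sum of the twelve written entries
ratio = exact / parsed           # exact belief entry divided by the truncated one

print(written)               # 0.083333
print("%.6f" % row_sum)      # 0.999996
print("%.5e" % (ratio - 1))  # 4.00002e-06
```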

When a POMDP agent in pomdp_py gets converted into a .pomdpx file for sarsop, this function gets called to write the file. You can see that floats are converted to strings via %f, which by default prints only six digits after the decimal point. That looks like a likely source of the error. As a temporary fix, try changing the string formatting to allow more digits (for example, %.9f gives nine decimal places). The problem you're facing might go away.
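The sketch below measures how much extra digits help. It compares `%f`, the `%.9f` suggested above, and `%.17g`, which gives enough significant digits to round-trip any IEEE 754 double. The helper `deficit` is hypothetical and only measures how far a written row of twelve 1/12 entries falls short of 1. It cannot tell you whether a given shortfall clears SARSOP's internal tolerance.

```python
N_OBS = 12
exact = 1 / N_OBS

def deficit(fmt):
    """How far a row of N_OBS entries, written with `fmt` and parsed back, falls short of 1."""
    parsed = float(fmt % exact)
    return 1.0 - sum([parsed] * N_OBS)

for fmt in ("%f", "%.9f", "%.17g"):
    # format string, the text that would be written, and the resulting shortfall
    print(fmt, fmt % exact, "%.1e" % deficit(fmt))
```

Here `%f` leaves a shortfall of about 4e-6 and `%.9f` about 4e-9. With `%.17g` the parsed value equals the original float, so this row sums to 1 exactly as it does in Python. Because 4e-9 is still larger than the 1e-10 tolerance quoted above, full round-trip precision seems the more conservative choice.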

Hororohoruru commented 1 year ago

Yeah, I was looking at the sarsop source code and found that MIN_RATIO_EPS is defined as 1e-10. I changed it to 1e-4 and it worked. I will try your suggestion as well. Thanks!

zkytony commented 1 year ago

I wouldn't change the source code of the solver, unless you really know what you're doing. You don't want to introduce potential bugs.

Hororohoruru commented 1 year ago

True! I forked the repo, changed the precision of the float-to-string conversion in to_pomdp_file, and now everything works. Thanks for your help.

zkytony commented 1 year ago

Nice. Feel free to open a PR about that change. For now, please make the PR against the dev-1.3.3 branch. Thanks!

Typo: "not everything works" should read "now everything works", I suppose.