ConvLab / ConvLab-3


[BUG] RL training tutorial #179

Open NoB0 opened 10 months ago

NoB0 commented 10 months ago

Describe the bug
The script example_train.py in Train RL Policies does not run.

To Reproduce
Steps to reproduce the behavior:

cd tutorials/Train_RL_Policies
python example_train.py

Error:

Traceback (most recent call last):
  File "convlab/tutorials/Train_RL_Policies/example_train.py", line 166, in <module>
    policy_sys = PPO(True)
  File "convlab/policy/ppo/ppo.py", line 49, in __init__
    self.vector = VectorBinary(dataset_name=kwargs['dataset_name'],
KeyError: 'dataset_name'

Expected behavior
The script should train a dialogue policy.

Actual behavior
The script fails to run.

Additional context
I assume the example is meant to run with the multiwoz21 dataset, so I modified example_train.py as shown below:

policy_sys = PPO(True, dataset_name='multiwoz21')

But then the state vector does not seem to have the expected size; see:

Traceback (most recent call last):
  File "/home/stud/nbernard/miniconda3/envs/convlab/lib/python3.9/multiprocessing/process.py", line 315, in _bootstrap
    self.run()
  File "/home/stud/nbernard/miniconda3/envs/convlab/lib/python3.9/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/tutorials/Train_RL_Policies/example_train.py", line 60, in sampler
    s_vec = torch.Tensor(policy.vector.state_vectorize(s))
ValueError: expected sequence of length 361 at dim 1 (got 208)

Any help on this matter would be appreciated.

zqwerty commented 8 months ago

Sorry for the late reply! @ChrisGeishauser, could you have a look at this issue (how to train an RL policy)?

Ahmed-Mahmod-Salem commented 6 months ago

It seems like a simple indexing error. The method policy.vector.state_vectorize(s) returns the state embedding along with a mask, so a simple fix is to take the 0th element:

s_vec = torch.Tensor(policy.vector.state_vectorize(s)[0])

Although I am not sure what the mask does, it seems like the example file and the PPO policy implementation are out of sync.
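
Putting the two changes from this thread together, here is a minimal sketch of how the relevant parts of example_train.py could be patched. The import path is inferred from the traceback above, and the helper name vectorize_state is just illustrative; the sketch assumes that state_vectorize returns the state vector as its first element, as suggested above.

import torch
from convlab.policy.ppo import PPO  # import path inferred from the traceback above

# Pass the dataset explicitly so the PPO constructor can forward
# dataset_name to VectorBinary.
policy_sys = PPO(True, dataset_name='multiwoz21')

def vectorize_state(policy, s):
    # state_vectorize appears to return (state_vector, mask); keep only the
    # state vector for the policy network and ignore the mask here.
    s_vec = policy.vector.state_vectorize(s)[0]
    return torch.Tensor(s_vec)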