takuseno / d3rlpy

An offline deep reinforcement learning library
https://takuseno.github.io/d3rlpy
MIT License

Customize encoder factory for SAC algorithm error #230

Closed lk1983823 closed 1 year ago

lk1983823 commented 1 year ago

Following the tutorial at https://d3rlpy.readthedocs.io/en/v1.1.1/tutorials/customize_neural_network.html, I tried to customize the encoders for the critic and the actor in SAC and ran into some errors. First, every def __init__() should start by calling super().__init__() (e.g. super(CustomEncoder, self).__init__()); otherwise it fails with AttributeError: cannot assign module before Module.__init__() call.
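For the first point, a minimal sketch of the pattern that works for me (the class name here is just illustrative):

import torch.nn as nn

class MinimalEncoder(nn.Module):
    def __init__(self, observation_shape, feature_size):
        # without this call, the assignment below fails with
        # AttributeError: cannot assign module before Module.__init__() call
        super(MinimalEncoder, self).__init__()
        self.fc1 = nn.Linear(observation_shape[0], feature_size)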

Second, even after fixing this, I still hit a new problem. When I run the following code:

import torch
import torch.nn as nn

import d3rlpy
from d3rlpy.models.encoders import EncoderFactory

class CustomEncoder(nn.Module):
    def __init__(self, observation_shape, feature_size):
        super(CustomEncoder, self).__init__()
        self.feature_size = feature_size
        self.fc1 = nn.Linear(observation_shape[0], feature_size)
        self.fc2 = nn.Linear(feature_size, feature_size)

    def forward(self, x):
        h = torch.relu(self.fc1(x))
        h = torch.relu(self.fc2(x))
        return h

    # THIS IS IMPORTANT!
    def get_feature_size(self):
        return self.feature_size

class CustomEncoderWithAction(nn.Module):

    def __init__(self, observation_shape, action_size, feature_size):
        super(CustomEncoderWithAction, self).__init__()
        self.feature_size = feature_size
        self.fc1 = nn.Linear(observation_shape[0] + action_size, feature_size)
        self.fc2 = nn.Linear(feature_size, feature_size)
        self.action_size = action_size

    def forward(self, x, action):
        h = torch.cat([x, action], dim=1)
        h = torch.relu(self.fc1(h))
        h = torch.relu(self.fc2(h))
        return h

    def get_feature_size(self):
        return self.feature_size

class CustomEncoderFactory(EncoderFactory):
    TYPE = 'custom' # this is necessary

    def __init__(self, feature_size):
        self.feature_size = feature_size

    def create(self, observation_shape):
        return CustomEncoder(observation_shape, self.feature_size)

    def create_with_action(self, observation_shape, action_size):
        return CustomEncoderWithAction(observation_shape, action_size, self.feature_size)

    def get_params(self, deep=False):
        return {'feature_size': self.feature_size}

encoder_factory = CustomEncoderFactory(5)

sac = d3rlpy.algos.SAC(
                        actor_encoder_factory=encoder_factory,
                        critic_encoder_factory=encoder_factory,
                        actor_learning_rate=3e-4,
                        scaler=state_scaler,
                        action_scaler=action_scaler,
                        reward_scaler=reward_scaler,
                        n_steps=3,
                        critic_learning_rate=3e-4,
                        temp_learning_rate=3e-4,
                        batch_size=256,
                        use_gpu=True)

sac.fit(dataset.episodes,
             eval_episodes=test_episodes,
             n_steps=500000,
             n_steps_per_epoch=1000,
             save_interval=10,
             scorers={
                 'td_errors':  d3rlpy.metrics.td_error_scorer,
#                  "environment": d3rlpy.metrics.evaluate_on_environment(env), # note this env
                 'value_scale': d3rlpy.metrics.average_value_estimation_scorer,
             },
             experiment_name=f"SAC_RZ_temp")

It shows:

2022-10-13 16:52.04 [debug    ] RandomIterator is selected.
2022-10-13 16:52.04 [info     ] Directory is created at d3rlpy_logs/SAC_RZ_temp_20221013165204
2022-10-13 16:52.04 [debug    ] Fitting scaler...              scaler=min_max
2022-10-13 16:52.04 [debug    ] Fitting action scaler...       action_scaler=min_max
2022-10-13 16:52.04 [debug    ] Fitting reward scaler...       reward_scaler=min_max
2022-10-13 16:52.04 [debug    ] Building models...
2022-10-13 16:52.04 [debug    ] Models have been built.
2022-10-13 16:52.04 [info     ] Parameters are saved to d3rlpy_logs/SAC_RZ_temp_20221013165204/params.json params={'action_scaler': {'type': 'min_max', 'params': {'minimum': array([[0.08318865, 0.54615164]], dtype=float32), 'maximum': array([[101.24059 , 100.343605]], dtype=float32)}}, 'actor_encoder_factory': {'type': 'custom', 'params': {'feature_size': 5}}, 'actor_learning_rate': 0.0003, 'actor_optim_factory': {'optim_cls': 'Adam', 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False}, 'batch_size': 256, 'critic_encoder_factory': {'type': 'custom', 'params': {'feature_size': 5}}, 'critic_learning_rate': 0.0003, 'critic_optim_factory': {'optim_cls': 'Adam', 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False}, 'gamma': 0.99, 'generated_maxlen': 100000, 'initial_temperature': 1.0, 'n_critics': 2, 'n_frames': 1, 'n_steps': 3, 'q_func_factory': {'type': 'mean', 'params': {'share_encoder': False}}, 'real_ratio': 1.0, 'reward_scaler': {'type': 'min_max', 'params': {'minimum': -40.7203369140625, 'maximum': 7.0, 'multiplier': 1.0}}, 'scaler': {'type': 'min_max', 'params': {'maximum': array([[307.13644 , 105.      ,   3.748553, 437.1429  ]], dtype=float32), 'minimum': array([[ 94.50818  ,  33.43284  ,   1.2541595, 355.27966  ]],
      dtype=float32)}}, 'tau': 0.005, 'temp_learning_rate': 0.0003, 'temp_optim_factory': {'optim_cls': 'Adam', 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False}, 'use_gpu': 0, 'algorithm': 'SAC', 'observation_shape': (4,), 'action_size': 2}
Epoch 1/500: 0%
0/1000 [00:00<?, ?it/s]
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Input In [64], in <cell line: 1>()
----> 1 sac.fit(dataset.episodes,
      2              eval_episodes=test_episodes,
      3              n_steps=500000,
      4              n_steps_per_epoch=1000,
      5              save_interval=10,
      6              scorers={
      7                  'td_errors':  d3rlpy.metrics.td_error_scorer,
      8 #                  "environment": d3rlpy.metrics.evaluate_on_environment(env), # note this env
      9                  'value_scale': d3rlpy.metrics.average_value_estimation_scorer,
     10              },
     11              experiment_name=f"SAC_RZ_temp")

File /media/lksgcc/new_disk/lk_git/3_Reinforcement_Learning/3_2_Offline_Learning/d3rlpy/d3rlpy/base.py:406, in LearnableBase.fit(self, dataset, n_epochs, n_steps, n_steps_per_epoch, save_metrics, experiment_name, with_timestamp, logdir, verbose, show_progress, tensorboard_dir, eval_episodes, save_interval, scorers, shuffle, callback)
    349 def fit(
    350     self,
    351     dataset: Union[List[Episode], List[Transition], MDPDataset],
   (...)
    368     callback: Optional[Callable[["LearnableBase", int, int], None]] = None,
    369 ) -> List[Tuple[int, Dict[str, float]]]:
    370     """Trains with the given dataset.
    371 
    372     .. code-block:: python
   (...)
    404 
    405     """
--> 406     results = list(
    407         self.fitter(
    408             dataset,
    409             n_epochs,
    410             n_steps,
    411             n_steps_per_epoch,
    412             save_metrics,
    413             experiment_name,
    414             with_timestamp,
    415             logdir,
    416             verbose,
    417             show_progress,
    418             tensorboard_dir,
    419             eval_episodes,
    420             save_interval,
    421             scorers,
    422             shuffle,
    423             callback,
    424         )
    425     )
    426     return results

File /media/lksgcc/new_disk/lk_git/3_Reinforcement_Learning/3_2_Offline_Learning/d3rlpy/d3rlpy/base.py:637, in LearnableBase.fitter(self, dataset, n_epochs, n_steps, n_steps_per_epoch, save_metrics, experiment_name, with_timestamp, logdir, verbose, show_progress, tensorboard_dir, eval_episodes, save_interval, scorers, shuffle, callback)
    635 # update parameters
    636 with logger.measure_time("algorithm_update"):
--> 637     loss = self.update(batch)
    639 # record metrics
    640 for name, val in loss.items():

File /media/lksgcc/new_disk/lk_git/3_Reinforcement_Learning/3_2_Offline_Learning/d3rlpy/d3rlpy/base.py:748, in LearnableBase.update(self, batch)
    738 def update(self, batch: TransitionMiniBatch) -> Dict[str, float]:
    739     """Update parameters with mini-batch of data.
    740 
    741     Args:
   (...)
    746 
    747     """
--> 748     loss = self._update(batch)
    749     self._grad_step += 1
    750     return loss

File /media/lksgcc/new_disk/lk_git/3_Reinforcement_Learning/3_2_Offline_Learning/d3rlpy/d3rlpy/algos/sac.py:202, in SAC._update(self, batch)
    200 # lagrangian parameter update for SAC temperature
    201 if self._temp_learning_rate > 0:
--> 202     temp_loss, temp = self._impl.update_temp(batch)
    203     metrics.update({"temp_loss": temp_loss, "temp": temp})
    205 critic_loss = self._impl.update_critic(batch)

File /media/lksgcc/new_disk/lk_git/3_Reinforcement_Learning/3_2_Offline_Learning/d3rlpy/d3rlpy/torch_utility.py:313, in train_api.<locals>.wrapper(self, *args, **kwargs)
    311 def wrapper(self: Any, *args: Any, **kwargs: Any) -> np.ndarray:
    312     set_train_mode(self)
--> 313     return f(self, *args, **kwargs)

File /media/lksgcc/new_disk/lk_git/3_Reinforcement_Learning/3_2_Offline_Learning/d3rlpy/d3rlpy/torch_utility.py:295, in torch_api.<locals>._torch_api.<locals>.wrapper(self, *args, **kwargs)
    292             tensor = tensor.float()
    294     tensors.append(tensor)
--> 295 return f(self, *tensors, **kwargs)

File /media/lksgcc/new_disk/lk_git/3_Reinforcement_Learning/3_2_Offline_Learning/d3rlpy/d3rlpy/algos/torch/sac_impl.py:135, in SACImpl.update_temp(self, batch)
    132 self._temp_optim.zero_grad()
    134 with torch.no_grad():
--> 135     _, log_prob = self._policy.sample_with_log_prob(batch.observations)
    136     targ_temp = log_prob - self._action_size
    138 loss = -(self._log_temp().exp() * targ_temp).mean()

File /media/lksgcc/new_disk/lk_git/3_Reinforcement_Learning/3_2_Offline_Learning/d3rlpy/d3rlpy/models/torch/policies.py:199, in NormalPolicy.sample_with_log_prob(self, x)
    196 def sample_with_log_prob(
    197     self, x: torch.Tensor
    198 ) -> Tuple[torch.Tensor, torch.Tensor]:
--> 199     out = self.forward(x, with_log_prob=True)
    200     return cast(Tuple[torch.Tensor, torch.Tensor], out)

File /media/lksgcc/new_disk/lk_git/3_Reinforcement_Learning/3_2_Offline_Learning/d3rlpy/d3rlpy/models/torch/policies.py:189, in NormalPolicy.forward(self, x, deterministic, with_log_prob)
    183 def forward(
    184     self,
    185     x: torch.Tensor,
    186     deterministic: bool = False,
    187     with_log_prob: bool = False,
    188 ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
--> 189     dist = self.dist(x)
    190     if deterministic:
    191         action, log_prob = dist.mean_with_log_prob()

File /media/lksgcc/new_disk/lk_git/3_Reinforcement_Learning/3_2_Offline_Learning/d3rlpy/d3rlpy/models/torch/policies.py:171, in NormalPolicy.dist(self, x)
    168 def dist(
    169     self, x: torch.Tensor
    170 ) -> Union[GaussianDistribution, SquashedGaussianDistribution]:
--> 171     h = self._encoder(x)
    172     mu = self._mu(h)
    173     clipped_logstd = self._compute_logstd(h)

File ~/.pyenv/versions/anaconda3-5.0.1/envs/mujoco_py/lib/python3.8/site-packages/torch/nn/modules/module.py:1110, in Module._call_impl(self, *input, **kwargs)
   1106 # If we don't have any hooks, we want to skip the rest of the logic in
   1107 # this function, and just call forward.
   1108 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1109         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1110     return forward_call(*input, **kwargs)
   1111 # Do not call functions when jit is used
   1112 full_backward_hooks, non_full_backward_hooks = [], []

Input In [51], in CustomEncoder.forward(self, x)
      8 def forward(self, x):
      9     h = torch.relu(self.fc1(x))
---> 10     h = torch.relu(self.fc2(x))
     11     return h

File ~/.pyenv/versions/anaconda3-5.0.1/envs/mujoco_py/lib/python3.8/site-packages/torch/nn/modules/module.py:1110, in Module._call_impl(self, *input, **kwargs)
   1106 # If we don't have any hooks, we want to skip the rest of the logic in
   1107 # this function, and just call forward.
   1108 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1109         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1110     return forward_call(*input, **kwargs)
   1111 # Do not call functions when jit is used
   1112 full_backward_hooks, non_full_backward_hooks = [], []

File ~/.pyenv/versions/anaconda3-5.0.1/envs/mujoco_py/lib/python3.8/site-packages/torch/nn/modules/linear.py:103, in Linear.forward(self, input)
    102 def forward(self, input: Tensor) -> Tensor:
--> 103     return F.linear(input, self.weight, self.bias)

RuntimeError: mat1 and mat2 shapes cannot be multiplied (256x4 and 5x5)

I have no idea how to fix this. I don't think the feature size needs to equal the observation size, so is there something wrong in the encoder-building process? In the above example, my observation size is 4. Thanks for the reply!
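The shapes in the error seem to come from fc2: its weight is (5, 5) because feature_size=5, yet it apparently received a (256, 4) batch, i.e. the raw observations. A standalone check (batch size 256, as in my run) reproduces the same message:

import torch
import torch.nn as nn

fc2 = nn.Linear(5, 5)    # feature_size -> feature_size, weight shape (5, 5)
x = torch.randn(256, 4)  # raw observation batch, observation size 4
fc2(x)                   # RuntimeError: mat1 and mat2 shapes cannot be multiplied (256x4 and 5x5)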

lk1983823 commented 1 year ago

I have now fixed this error! In the CustomEncoder class, h = torch.relu(self.fc2(x)) should be h = torch.relu(self.fc2(h)); fc2 was being fed the raw observation x (shape 256x4) instead of the fc1 output h (shape 256x5), which is exactly the shape mismatch in the traceback. The corrected class:

class CustomEncoder(nn.Module):
    def __init__(self, observation_shape, feature_size):
        super(CustomEncoder, self).__init__()
        self.feature_size = feature_size
        self.fc1 = nn.Linear(observation_shape[0], feature_size)
        self.fc2 = nn.Linear(feature_size, feature_size)

    def forward(self, x):
        h = torch.relu(self.fc1(x))
        h = torch.relu(self.fc2(h))
        return h
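As a quick sanity check (assuming an observation shape of (4,) and a batch of 256, as in my run), the corrected encoder now produces the expected feature shape:

import torch

encoder = CustomEncoder(observation_shape=(4,), feature_size=5)
x = torch.randn(256, 4)            # dummy observation batch
print(encoder(x).shape)            # torch.Size([256, 5])
print(encoder.get_feature_size())  # 5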

I suggest updating the tutorial on the website accordingly.

takuseno commented 1 year ago

@lk1983823 Thank you for reporting this! It's fixed in this commit: https://github.com/takuseno/d3rlpy/commit/cd7681ca150d89422f9865daaaa896ead13a7b73 . The change will be reflected in the latest documentation: https://d3rlpy.readthedocs.io/en/latest/