RoepStoep closed this issue 2 months ago.
I suspect the problem originates in the predict/predict_target functions of the ContinuousCritic class in the gbrl repository, specifically in policy_dim = self.output_dim // 2 at ac_gbrl.py:619. It seems that a single bias is expected for q_func_type linear/tanh, which suggests that policy_dim = self.output_dim - len(self.bias) is correct (at least for my use case). The problem would go unnoticed if the code were only tested with action_dim = 1, because then output_dim = 2 (one weight plus one bias), and output_dim // 2 = 1 happens to equal output_dim - 1.
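To illustrate why the two splitting rules only diverge for action_dim > 1, here is a minimal NumPy sketch. It assumes the critic output is laid out as action_dim weight columns followed by a single bias column (theta_dim = 1), as described in this thread; split_old and split_fixed are hypothetical helpers, not part of the gbrl API.

```python
import numpy as np


def split_old(output: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    # Hypothetical reproduction of the reported rule: policy_dim = output_dim // 2
    policy_dim = output.shape[1] // 2
    return output[:, :policy_dim], output[:, policy_dim:]


def split_fixed(output: np.ndarray, n_bias: int = 1) -> tuple[np.ndarray, np.ndarray]:
    # Suggested rule: policy_dim = output_dim - len(bias); n_bias=1 is an assumption
    policy_dim = output.shape[1] - n_bias
    return output[:, :policy_dim], output[:, policy_dim:]


batch = 5
for action_dim in (1, 3):
    # Assumed layout: action_dim weights followed by one bias column
    out = np.zeros((batch, action_dim + 1))
    w_old, b_old = split_old(out)
    w_new, b_new = split_fixed(out)
    print(action_dim, w_old.shape, b_old.shape, w_new.shape, b_new.shape)

# prints:
# 1 (5, 1) (5, 1) (5, 1) (5, 1)
# 3 (5, 2) (5, 2) (5, 3) (5, 1)
```

With action_dim = 1 both rules agree, so a test using a single action dimension cannot distinguish them.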
Thank you for digging into the code. The latest GBRL release, version 1.0.3, fixes this issue.
First of all, I want to thank you for your great work; it's a very interesting repo! I'm currently trying to get the SAC-GBRL implementation working in one of my projects. As the README says, this is currently a beta implementation, so bugs are expected. I'm encountering a problem in the forward step at policies/sac_policy.py:119, where the dot product (weights * actions) is calculated.
I'm using a custom environment with action_dim=3 and q_func_type='linear', which gives theta_dim=1, so the q_model (a gbrl.ContinuousCritic) is instantiated with output_dim=4. The q_model's predict() call therefore returns a (weights, bias) tuple whose elements both have an outer dimension of 2. This leads to incompatible shapes in the calculation dot = (weights * actions).sum(dim=1), since weights has outer dimension 2 while actions has outer dimension 3.
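The shape mismatch can be reproduced outside the library. The sketch below assumes the linear Q-head computes Q(s, a) = sum_i w_i(s) * a_i + b(s), matching the dot product quoted above; it uses NumPy instead of the PyTorch tensors in sac_policy.py, and linear_q is a hypothetical helper, not the repository's code.

```python
import numpy as np


def linear_q(weights: np.ndarray, bias: np.ndarray, actions: np.ndarray) -> np.ndarray:
    # Assumed linear Q-head: Q(s, a) = sum_i w_i(s) * a_i + b(s)
    # Expected shapes: weights (batch, action_dim), bias (batch, 1), actions (batch, action_dim)
    dot = (weights * actions).sum(axis=1)
    return dot + bias[:, 0]


batch, action_dim = 4, 3
actions = np.ones((batch, action_dim))
# Critic output with output_dim = action_dim + theta_dim = 4 columns
out = np.arange(batch * (action_dim + 1), dtype=float).reshape(batch, action_dim + 1)

# Reported split (output_dim // 2 = 2): weights has 2 columns but actions has 3
try:
    linear_q(out[:, :2], out[:, 2:], actions)
except ValueError as err:
    print("shape mismatch:", err)

# Split with a single bias column (output_dim - 1 = 3): the shapes line up
q = linear_q(out[:, :action_dim], out[:, action_dim:], actions)
print(q)  # [ 6. 22. 38. 54.]
```

NumPy raises a broadcasting ValueError for (4, 2) * (4, 3); PyTorch raises an analogous RuntimeError for the same shapes.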
I'm not familiar enough with the theoretical background to know what goes wrong here or how to fix it, so I'm hoping you can help me out. Thanks a lot if you can find the time!