chrisgao99 closed this issue 1 week ago.
Thank you so much for the guidance. Now I know how to exclude the external LLM when saving the model.
For anyone else with the same problem: I pass a VLM to the DQN class through policy_kwargs:
model = CustomDQN(CustomDQNPolicy, env, verbose=1, tensorboard_log="logs/tb/llmguide_dqn", learning_starts=10000,
                  policy_kwargs={"vlm_model": VLM, "vlm_model_name": vlm_model_name})
To avoid saving the VLM, I added 'policy_kwargs' to the list returned by _excluded_save_params():
from stable_baselines3 import DQN

class CustomDQN(DQN):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def _excluded_save_params(self):
        # Also skip policy_kwargs, which holds the VLM object
        return super()._excluded_save_params() + ['policy_kwargs']
If you are not sure which parts to exclude, you can print everything that is about to be saved just before the save_to_zip_file call here: https://github.com/DLR-RM/stable-baselines3/blob/56c153f048f1035f239b77d1569b240ace83c130/stable_baselines3/common/base_class.py#L866
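If you would rather list those attributes without editing the library, a rough sketch like the one below mirrors the exclusion steps that save() appears to perform around the linked line: start from the model's `__dict__`, drop the names returned by `_excluded_save_params()`, and drop the top-level names of the PyTorch parameters returned by `_get_torch_save_params()` (those are stored separately). The helper `saved_attribute_names` is hypothetical, and it relies on private SB3 methods that may change between versions.

```python
def saved_attribute_names(model):
    """Approximate the attribute names that model.save() would serialize.

    Hypothetical helper: mirrors the exclusion logic in BaseAlgorithm.save()
    and relies on private SB3 methods that may change between versions.
    """
    # Attributes that are never written to the "data" part of the zip file
    exclude = set(model._excluded_save_params())
    # PyTorch state dicts/variables are saved separately, so their
    # top-level attribute names are excluded from "data" as well
    state_dict_names, torch_variable_names = model._get_torch_save_params()
    for name in list(state_dict_names) + list(torch_variable_names):
        exclude.add(name.split(".")[0])
    return sorted(key for key in model.__dict__ if key not in exclude)
```

For example, `print(saved_attribute_names(model))` before training: if `policy_kwargs` shows up in the list, whatever it holds (here, the VLM) would be serialized along with the rest of the model.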
🐛 Bug
Hello,
I wrote a customized DQN policy that uses a large language model (LLM) to modify the Q-values before the DQN policy predicts an action. The idea is simple: every time the agent receives an observation, it queries the LLM for expert log-probabilities and adds them to the original q_values so that it can choose a better action.
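The selection rule described above (add the expert's log-probabilities to the Q-values, then act greedily) can be sketched in plain Python. This only illustrates the idea and is not the poster's QNetwork code: `guided_action` and the weight `beta` are hypothetical, and it assumes the LLM returns one log-probability per discrete action, in the same order as the Q-values.

```python
import math
from typing import Sequence


def guided_action(q_values: Sequence[float],
                  expert_log_probs: Sequence[float],
                  beta: float = 1.0) -> int:
    """Return the greedy action after adding weighted expert log-probs.

    beta is a hypothetical scaling factor; beta=0 recovers plain greedy DQN.
    """
    if len(q_values) != len(expert_log_probs):
        raise ValueError("need one expert log-prob per action")
    scores = [q + beta * lp for q, lp in zip(q_values, expert_log_probs)]
    # argmax over the combined scores (the first index wins on ties)
    return max(range(len(scores)), key=scores.__getitem__)


q = [1.0, 1.2, 0.5]
expert = [math.log(p) for p in (0.7, 0.2, 0.1)]
print(guided_action(q, expert))            # 0: the expert shifts the choice
print(guided_action(q, expert, beta=0.0))  # 1: plain greedy DQN action
```

Q-values and log-probabilities live on different scales, which is why a weight like `beta` is worth exposing: it controls how strongly the LLM can override the learned values.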
My problem is that when SB3 saves the best-model checkpoint, it also saves the external LLM part, which leads to an error. Is there a way to keep that part of my customized DQN policy out of the checkpoint? I tried modifying save() in DQN, but I still get the same error.
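The issue does not include the actual error message, so this is only a guess at the failure mode: SB3 serializes the attributes it saves, and if one of them (here `policy_kwargs`, which holds the LLM/VLM) contains an object that cannot be pickled, saving fails. The standard-library sketch below reproduces that pattern, using a `threading.Lock` as a stand-in for an unpicklable model object; `FakeAgent`, `FakeAgentWithExclusion`, and `serialize` are hypothetical names, not SB3 code.

```python
import pickle
import threading


class FakeAgent:
    """Hypothetical stand-in for a model whose policy_kwargs hold a VLM."""

    def __init__(self):
        self.learning_starts = 10000
        # A lock cannot be pickled, standing in for an unpicklable VLM
        self.policy_kwargs = {"vlm_model": threading.Lock()}

    def _excluded_save_params(self):
        return []


class FakeAgentWithExclusion(FakeAgent):
    def _excluded_save_params(self):
        # Same idea as the CustomDQN override: skip policy_kwargs
        return super()._excluded_save_params() + ["policy_kwargs"]


def serialize(agent):
    """Pickle every attribute except the excluded ones."""
    excluded = set(agent._excluded_save_params())
    data = {k: v for k, v in vars(agent).items() if k not in excluded}
    return pickle.dumps(data)


try:
    serialize(FakeAgent())
except TypeError as err:
    print("save failed:", err)
print(len(serialize(FakeAgentWithExclusion())) > 0)  # True
```

Excluding the attribute at the class level (rather than only in one save() call) matters here because checkpoint callbacks call save() themselves.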
Here is my full code:
The key part is the query_expert function in the custom QNetwork:
Could anyone suggest how to avoid saving the query_expert() part when saving a checkpoint?
To Reproduce
No response
Relevant log output / Error message
System Info
No response
Checklist