DLR-RM / stable-baselines3

PyTorch version of Stable Baselines, reliable implementations of reinforcement learning algorithms.
https://stable-baselines3.readthedocs.io
MIT License
8.38k stars 1.61k forks

[Feature Request] Enable predict to take tensor as input #1896

Closed llewynS closed 2 months ago

llewynS commented 2 months ago

🚀 Feature

Enable predict to accept tensors so that they don't have to be moved to the CPU before being passed in.

BasePolicy.predict currently accepts Union[np.ndarray, Dict[str, np.ndarray]]

I propose that accepted inputs change to Union[np.ndarray, Dict[str, np.ndarray], torch.Tensor, Dict[str, torch.Tensor]]

Motivation

predict currently only accepts numpy arrays. This is inefficient when using a Stable Baselines model with custom environments or training loops: observations that already live on the GPU have to be copied to the CPU, only for the SB3 agent to load them back onto the GPU.

Pitch

A performance optimisation whereby sb3agent.predict(mytorchtensor) works.

Alternatives

The alternative is to keep it the same, which requires the transfer between CPU and GPU, which is slow.

Additional context

If this were extended to let the replay buffer store tensors, that could also be a boon. There isn't much need to keep everything off the GPU now that GPU memory is so large.

araffin commented 2 months ago

You can already do that using the forward() method of the policy (or _predict()): https://github.com/DLR-RM/stable-baselines3/blob/5623d98f9d6bcfd2ab450e850c3f7b090aef5642/stable_baselines3/common/policies.py#L365-L368

See https://stable-baselines3.readthedocs.io/en/master/guide/export.html#export-to-onnx and https://github.com/DLR-RM/stable-baselines3/issues/568 (and https://github.com/DLR-RM/stable-baselines3/issues/385)

EDIT: if you just want to use predict, you probably need to put the model on cpu and not use the GPU at all

llewynS commented 2 months ago

Thanks mate, that appears to have worked. I opted for:

with torch.no_grad():
    mymodel.policy._predict(inp_dict)

I probably should have read past line 365 and not been so hasty.

llewynS commented 2 months ago

Feature not required