araffin / sbx

SBX: Stable Baselines Jax (SB3 + Jax)
MIT License

SBX becomes super slow when the number of CPUs is limited #27

Closed. Deepakgthomas closed this issue 9 months ago.

Deepakgthomas commented 9 months ago

🐛 Bug

SBX becomes much slower than SB3 when the number of CPUs available to the process is limited (here, by restricting its CPU affinity).

To Reproduce

Steps to reproduce the behavior.

```python
# Installation:
#   pip install gym sbx mujoco shimmy
#   pip install tqdm rich  # required by progress_bar=True
import os
import random

import gym
import psutil

# To compare against SB3, uncomment the next line and comment out the sbx import.
# from stable_baselines3 import SAC
from sbx import SAC


def train():
    num_of_cpus = 4
    pid = os.getpid()
    process = psutil.Process(pid)
    print("Process = ", pid)

    # Pin this process to a random subset of the CPUs it is currently allowed to use
    affinity = process.cpu_affinity()
    cpus_selected = random.sample(affinity, min(num_of_cpus, len(affinity)))
    print("cpus_selected = ", cpus_selected)
    process.cpu_affinity(cpus_selected)

    env = gym.make("Humanoid-v4")
    model = SAC("MlpPolicy", env, verbose=1)
    # total_timesteps is an integer step count
    model.learn(total_timesteps=7_000, progress_bar=True)


if __name__ == "__main__":
    train()
```
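The reproduction uses psutil to restrict the process's CPU affinity. The following stdlib-only sketch does the same thing (Linux-only, assuming `os.sched_getaffinity`/`os.sched_setaffinity` are available; `pin_to_cpus` and `report_cpu_counts` are hypothetical helper names). It also surfaces a detail worth checking when limiting CPUs this way: `os.cpu_count()` reports every logical CPU on the machine, not just the ones the process may run on. A library that sizes its thread pools from the machine-wide count could start more threads than the pinned cores can serve. Whether that applies to JAX/XLA here is an assumption, not something established in this issue.

```python
import os


def pin_to_cpus(num_cpus):
    """Restrict the current process to the first `num_cpus` CPUs it may use.

    Linux-only: relies on os.sched_getaffinity / os.sched_setaffinity.
    """
    allowed = sorted(os.sched_getaffinity(0))
    selected = set(allowed[:num_cpus])
    os.sched_setaffinity(0, selected)
    return selected


def report_cpu_counts():
    """Compare the machine-wide CPU count with the CPUs this process may use."""
    return {
        # Total logical CPUs on the machine; ignores the affinity mask
        "os.cpu_count": os.cpu_count(),
        # CPUs the current process is actually allowed to run on
        "usable": len(os.sched_getaffinity(0)),
    }


if __name__ == "__main__":
    print("before:", report_cpu_counts())
    pin_to_cpus(4)
    print("after:", report_cpu_counts())
```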

Expected behavior

To compare SB3 with SBX, uncomment `from stable_baselines3 import SAC` and comment out `from sbx import SAC`. With the CPUs limited as above, SB3 runs much faster than SBX.

System Info

Describe the characteristics of your environment:

You can use sb3.get_system_info() to print information about the relevant packages:

```python
import stable_baselines3 as sb3
sb3.get_system_info()
```

```
{'Cloudpickle': '3.0.0',
 'GPU Enabled': 'True',
 'Gymnasium': '0.29.1',
 'Numpy': '1.24.3',
 'OS': 'Linux-6.5.0-15-generic-x86_64-with-glibc2.17 # 15~22.04.1-Ubuntu SMP '
       'PREEMPT_DYNAMIC Fri Jan 12 18:54:30 UTC 2',
 'OpenAI Gym': '0.26.2',
 'PyTorch': '2.1.2+cu121',
 'Python': '3.8.18',
 'Stable-Baselines3': '2.3.0a1'}
```


araffin commented 9 months ago

Hello, could you please provide some information about the CPU you are using and how many cores you have? Same for the GPU. Which jaxlib version are you using? You should also be using gymnasium instead of gym.
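The details requested here (jax/jaxlib versions, core counts, and whether a GPU is visible) are not covered by `sb3.get_system_info()`. A minimal sketch for gathering them, assuming Python 3.8+ for `importlib.metadata`; `package_version` and `collect_jax_cpu_info` are hypothetical helper names. The CPU model itself is easier to read from a tool such as `lscpu`.

```python
import os
from importlib import metadata


def package_version(name):
    """Return the installed version of `name`, or None if it is not installed."""
    try:
        return metadata.version(name)
    except metadata.PackageNotFoundError:
        return None


def collect_jax_cpu_info():
    """Gather jax/jaxlib versions, CPU counts and the devices JAX sees."""
    info = {
        "jax": package_version("jax"),
        "jaxlib": package_version("jaxlib"),
        # Total logical CPUs on the machine (ignores any affinity restriction)
        "cpu_count": os.cpu_count(),
    }
    # sched_getaffinity is Linux-only; it reports the CPUs this process may use
    if hasattr(os, "sched_getaffinity"):
        info["usable_cpus"] = len(os.sched_getaffinity(0))
    if info["jax"] is not None:
        import jax

        # If only CPU devices are listed, JAX is not using the GPU
        info["devices"] = [str(d) for d in jax.devices()]
    return info


if __name__ == "__main__":
    for key, value in collect_jax_cpu_info().items():
        print(f"{key}: {value}")
```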

> SBX becomes much slower than SB3 when the number of cpus are limited

In your case, are you sure the GPU is used? How much slower is SBX? Or is the difference due to compilation time? (If so, https://github.com/araffin/sbx/pull/21 will help.)
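One way to tell compilation time apart from steady-state speed is to time the first call of the slow code separately from later calls. A rough sketch, assuming the slow part can be wrapped in a callable; `time_first_and_warm`, `_expensive_setup` and `step` are hypothetical names, and `_expensive_setup` only simulates a one-time cost (such as JIT compilation) with `time.sleep`.

```python
import statistics
import time
from functools import lru_cache


def time_first_and_warm(fn, *args, repeats=3):
    """Time the first call of `fn` separately from `repeats` later calls.

    With a jax.jit-compiled function the first call also includes tracing and
    compilation. JAX dispatches work asynchronously, so `fn` should call
    .block_until_ready() on its JAX outputs for the timings to be meaningful.
    """
    start = time.perf_counter()
    fn(*args)
    first = time.perf_counter() - start

    warm = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn(*args)
        warm.append(time.perf_counter() - start)
    return first, statistics.mean(warm)


@lru_cache(maxsize=None)
def _expensive_setup():
    # Stand-in for a one-time cost; cached, so it only runs on the first call
    time.sleep(0.2)
    return 42


def step(x):
    return _expensive_setup() + x


if __name__ == "__main__":
    first, warm = time_first_and_warm(step, 1)
    print(f"first call: {first:.3f}s, warm calls (mean): {warm:.6f}s")
```

A first call far slower than the warm calls points to a one-time cost such as compilation rather than slower training steps.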

Could you share a run using W&B? For instance: https://wandb.ai/openrlbenchmark/sb3/reports/SB3-vs-SBX--VmlldzoyOTczNDg2

Deepakgthomas commented 9 months ago

My apologies. My program was buggy.