Eclectic-Sheep / sheeprl

Distributed Reinforcement Learning accelerated by Lightning Fabric
https://eclecticsheep.ai
Apache License 2.0

Is there an easier way to test checkpoint? #318

Closed chrisgao99 closed 1 month ago

chrisgao99 commented 2 months ago

Hello sheeprl team,

The checkpoint evaluation is great, but it's perhaps too complicated for a beginner to really utilize. For example, I want to run a simple test of a dreamer_v3 checkpoint trained in my custom env, like this:

model = load(ckpt_path)
obs, info = env.reset(seed=test_seed)
done = False
while not done:
    action = model(obs)
    obs, reward, terminated, truncated, info = env.step(action)
    done = terminated or truncated
    save_video_frame(obs)

I can't find any code that looks like this in the repo, and my env doesn't support a recorded_frame attribute like the Atari ones do. So may I ask if there is an easy way to do evaluation like this, so that I can save the frames from the obs at every step?

Also, if I want to learn and better utilize the current evaluation loop in sheeprl, could you give me a good place to start? Thanks a lot.

michele-milesi commented 2 months ago

Hi @chrisgao99, thanks for your question. SheepRL provides checkpoint evaluation through the sheeprl-eval command (or the sheeprl_eval.py script). The script calls the evaluate() function in the algorithm's evaluate.py file (in your case, ./sheeprl/algos/dreamer_v3/evaluate.py). The evaluate() function instantiates the environment and the agent, and then calls the algorithm's test function (placed in the ./sheeprl/algos/dreamer_v3/utils.py file).

If your environment implements the render function with the rgb_array mode available, you can save a video of your evaluation by setting the env.capture_video argument to True when calling the evaluation script. For more information, check the how-to section.
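
For example, an eval invocation along these lines should work (a sketch based on the Hydra-style overrides used in the sheeprl docs; the checkpoint path is a placeholder to fill in):

sheeprl-eval checkpoint_path=/path/to/your/checkpoint.ckpt fabric.accelerator=gpu env.capture_video=True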

If you want to save the video frames for other reasons, you can find the while loop in the test() function in the ./sheeprl/algos/dreamer_v3/utils.py file, and you can change that function. IMPORTANT: if you want to modify the test() function, you need to install sheeprl in editable mode with the pip install -e .[<optional deps>] command (before this, make sure to uninstall sheeprl with pip uninstall sheeprl -y).
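
As a reference for that change, here is a minimal sketch of such a frame-saving loop. It is not the actual test() code: make_env and agent_act are placeholders for your environment constructor and your loaded agent, and it assumes the observations are HxWx3 uint8 images:

import imageio
import numpy as np

env = make_env()  # placeholder: build your custom gymnasium env
obs, info = env.reset(seed=42)
frames = [np.asarray(obs, dtype=np.uint8)]  # save the initial observation
done = False
while not done:
    action = agent_act(obs)  # placeholder: greedy action from the loaded agent
    obs, reward, terminated, truncated, info = env.step(action)
    done = terminated or truncated
    frames.append(np.asarray(obs, dtype=np.uint8))
imageio.mimsave("evaluation.mp4", frames, fps=30)  # requires the imageio-ffmpeg backend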

I hope everything is clear; I await your feedback. Thanks!

chrisgao99 commented 2 months ago

Thank you so much for the instructions on evaluation. They help a lot.

May I ask another question, about training? I tried training Atari Pong with dreamer_v3, but it takes 20 hours for 345600 steps. The device I'm using is a single RTX 4090 GPU, and the run usually takes about 6200 MB of GPU memory and 6.8 GB of system memory. Is this speed normal? I know the 345600 counts only the steps of interaction with the real env, not the imagined interactions in the world model, so I'm not sure whether this is too slow. But the reward is still around -20 after 20 hours of training, which is really bad for Pong.

Mostly I used the default configs. The ones I changed are:

dreamer_v3_M
env:
    screen_size: 64, frame_skip: 1, action_repeat: 1, num_envs: 8
fabric:
    accelerator: gpu

I've also tried num_envs=1, but it's just as slow.

Could you help me figure out what might be going wrong? I truly appreciate your help. Here is the full config I use:

CONFIG
├── algo
│   └── name: dreamer_v3                                                        
│       total_steps: 1600000                                                    
│       per_rank_batch_size: 16                                                 
│       run_test: true                                                          
│       cnn_keys:                                                               
│         encoder:                                                              
│         - rgb                                                                 
│         decoder:                                                              
│         - rgb                                                                 
│       mlp_keys:                                                               
│         encoder: []                                                           
│         decoder: []                                                           
│       world_model:                                                            
│         optimizer:                                                            
│           _target_: torch.optim.Adam                                          
│           lr: 0.0001                                                          
│           eps: 1.0e-08                                                        
│           weight_decay: 0                                                     
│           betas:                                                              
│           - 0.9                                                               
│           - 0.999                                                             
│         discrete_size: 32                                                     
│         stochastic_size: 32                                                   
│         kl_dynamic: 0.5                                                       
│         kl_representation: 0.1                                                
│         kl_free_nats: 1.0                                                     
│         kl_regularizer: 1.0                                                   
│         continue_scale_factor: 1.0                                            
│         clip_gradients: 1000.0                                                
│         decoupled_rssm: false                                                 
│         learnable_initial_recurrent_state: true                               
│         encoder:                                                              
│           cnn_channels_multiplier: 48                                         
│           cnn_act: torch.nn.SiLU                                              
│           dense_act: torch.nn.SiLU                                            
│           mlp_layers: 3                                                       
│           cnn_layer_norm:                                                     
│             cls: sheeprl.models.models.LayerNormChannelLast                   
│             kw:                                                               
│               eps: 0.001                                                      
│           mlp_layer_norm:                                                     
│             cls: sheeprl.models.models.LayerNorm                              
│             kw:                                                               
│               eps: 0.001                                                      
│           dense_units: 640                                                    
│         recurrent_model:                                                      
│           recurrent_state_size: 1024                                          
│           layer_norm:                                                         
│             cls: sheeprl.models.models.LayerNorm                              
│             kw:                                                               
│               eps: 0.001                                                      
│           dense_units: 640                                                    
│         transition_model:                                                     
│           hidden_size: 640                                                    
│           dense_act: torch.nn.SiLU                                            
│           layer_norm:                                                         
│             cls: sheeprl.models.models.LayerNorm                              
│             kw:                                                               
│               eps: 0.001                                                      
│         representation_model:                                                 
│           hidden_size: 640                                                    
│           dense_act: torch.nn.SiLU                                            
│           layer_norm:                                                         
│             cls: sheeprl.models.models.LayerNorm                              
│             kw:                                                               
│               eps: 0.001                                                      
│         observation_model:                                                    
│           cnn_channels_multiplier: 48                                         
│           cnn_act: torch.nn.SiLU                                              
│           dense_act: torch.nn.SiLU                                            
│           mlp_layers: 3                                                       
│           cnn_layer_norm:                                                     
│             cls: sheeprl.models.models.LayerNormChannelLast                   
│             kw:                                                               
│               eps: 0.001                                                      
│           mlp_layer_norm:                                                     
│             cls: sheeprl.models.models.LayerNorm                              
│             kw:                                                               
│               eps: 0.001                                                      
│           dense_units: 640                                                    
│         reward_model:                                                         
│           dense_act: torch.nn.SiLU                                            
│           mlp_layers: 3                                                       
│           layer_norm:                                                         
│             cls: sheeprl.models.models.LayerNorm                              
│             kw:                                                               
│               eps: 0.001                                                      
│           dense_units: 640                                                    
│           bins: 255                                                           
│         discount_model:                                                       
│           learnable: true                                                     
│           dense_act: torch.nn.SiLU                                            
│           mlp_layers: 3                                                       
│           layer_norm:                                                         
│             cls: sheeprl.models.models.LayerNorm                              
│             kw:                                                               
│               eps: 0.001                                                      
│           dense_units: 640                                                    
│       actor:                                                                  
│         optimizer:                                                            
│           _target_: torch.optim.Adam                                          
│           lr: 8.0e-05                                                         
│           eps: 1.0e-05                                                        
│           weight_decay: 0                                                     
│           betas:                                                              
│           - 0.9                                                               
│           - 0.999                                                             
│         cls: sheeprl.algos.dreamer_v3.agent.Actor                             
│         ent_coef: 0.0003                                                      
│         min_std: 0.1                                                          
│         max_std: 1.0                                                          
│         init_std: 2.0                                                         
│         dense_act: torch.nn.SiLU                                              
│         mlp_layers: 3                                                         
│         layer_norm:                                                           
│           cls: sheeprl.models.models.LayerNorm                                
│           kw:                                                                 
│             eps: 0.001                                                        
│         dense_units: 640                                                      
│         clip_gradients: 100.0                                                 
│         unimix: 0.01                                                          
│         action_clip: 1.0                                                      
│         moments:                                                              
│           decay: 0.99                                                         
│           max: 1.0                                                            
│           percentile:                                                         
│             low: 0.05                                                         
│             high: 0.95                                                        
│       critic:                                                                 
│         optimizer:                                                            
│           _target_: torch.optim.Adam                                          
│           lr: 8.0e-05                                                         
│           eps: 1.0e-05                                                        
│           weight_decay: 0                                                     
│           betas:                                                              
│           - 0.9                                                               
│           - 0.999                                                             
│         dense_act: torch.nn.SiLU                                              
│         mlp_layers: 3                                                         
│         layer_norm:                                                           
│           cls: sheeprl.models.models.LayerNorm                                
│           kw:                                                                 
│             eps: 0.001                                                        
│         dense_units: 640                                                      
│         per_rank_target_network_update_freq: 1                                
│         tau: 0.02                                                             
│         bins: 255                                                             
│         clip_gradients: 100.0                                                 
│       gamma: 0.996996996996997                                                
│       lmbda: 0.95                                                             
│       horizon: 15                                                             
│       replay_ratio: 1                                                         
│       learning_starts: 1024                                                   
│       per_rank_pretrain_steps: 0                                              
│       per_rank_sequence_length: 64                                            
│       cnn_layer_norm:                                                         
│         cls: sheeprl.models.models.LayerNormChannelLast                       
│         kw:                                                                   
│           eps: 0.001                                                          
│       mlp_layer_norm:                                                         
│         cls: sheeprl.models.models.LayerNorm                                  
│         kw:                                                                   
│           eps: 0.001                                                          
│       dense_units: 640                                                        
│       mlp_layers: 3                                                           
│       dense_act: torch.nn.SiLU                                                
│       cnn_act: torch.nn.SiLU                                                  
│       unimix: 0.01                                                            
│       hafner_initialization: true                                             
│       player:                                                                 
│         discrete_size: 32                                                     
│                                                                               
├── buffer
│   └── size: 1000                                                              
│       memmap: true                                                            
│       validate_args: false                                                    
│       from_numpy: false                                                       
│       checkpoint: true                                                        
│                                                                               
├── checkpoint
│   └── every: 1600                                                             
│       resume_from: null                                                       
│       save_last: true                                                         
│       keep_last: 5                                                            
│                                                                               
├── env
│   └── id: PongNoFrameskip-v4                                                  
│       num_envs: 8                                                             
│       frame_stack: 1                                                          
│       sync_env: false                                                         
│       screen_size: 64                                                         
│       action_repeat: 1                                                        
│       grayscale: false                                                        
│       clip_rewards: false                                                     
│       capture_video: true                                                     
│       frame_stack_dilation: 1                                                 
│       actions_as_observation:                                                 
│         num_stack: -1                                                         
│         noop: You MUST define the NOOP                                        
│         dilation: 1                                                           
│       max_episode_steps: 27000                                                
│       reward_as_observation: false                                            
│       wrapper:                                                                
│         _target_: gymnasium.wrappers.AtariPreprocessing                       
│         env:                                                                  
│           _target_: gymnasium.make                                            
│           id: PongNoFrameskip-v4                                              
│           render_mode: rgb_array                                              
│         noop_max: 30                                                          
│         terminal_on_life_loss: false                                          
│         frame_skip: 1                                                         
│         screen_size: 64                                                       
│         grayscale_obs: false                                                  
│         scale_obs: false                                                      
│         grayscale_newaxis: true                                               
│       test_seedlist: null                                                     
│       frame_skip: 1                                                           
│                                                                               
├── fabric
│   └── _target_: lightning.fabric.Fabric                                       
│       devices: 1                                                              
│       num_nodes: 1                                                            
│       strategy: auto                                                          
│       accelerator: gpu                                                        
│       precision: 32-true                                                      
│       callbacks:                                                              
│       - _target_: sheeprl.utils.callback.CheckpointCallback                   
│         keep_last: 5                                                          
│                                                                               
└── metric
    └── log_every: 8                                                            
        disable_timer: false                                                    
        log_level: 1                                                            
        sync_on_compute: false                                                  
        aggregator:                                                             
          _target_: sheeprl.utils.metric.MetricAggregator                       
          raise_on_missing: false                                               
          metrics:                                                              
            Rewards/rew_avg:                                                    
              _target_: torchmetrics.MeanMetric                                 
              sync_on_compute: false                                            
            Game/ep_len_avg:                                                    
              _target_: torchmetrics.MeanMetric                                 
              sync_on_compute: false                                            
            Loss/world_model_loss:                                              
              _target_: torchmetrics.MeanMetric                                 
              sync_on_compute: false                                            
            Loss/value_loss:                                                    
              _target_: torchmetrics.MeanMetric                                 
              sync_on_compute: false                                            
            Loss/policy_loss:                                                   
              _target_: torchmetrics.MeanMetric                                 
              sync_on_compute: false                                            
            Loss/observation_loss:                                              
              _target_: torchmetrics.MeanMetric                                 
              sync_on_compute: false                                            
            Loss/reward_loss:                                                   
              _target_: torchmetrics.MeanMetric                                 
              sync_on_compute: false                                            
            Loss/state_loss:                                                    
              _target_: torchmetrics.MeanMetric                                 
              sync_on_compute: false                                            
            Loss/continue_loss:                                                 
              _target_: torchmetrics.MeanMetric                                 
              sync_on_compute: false                                            
            State/kl:                                                           
              _target_: torchmetrics.MeanMetric                                 
              sync_on_compute: false                                            
            State/post_entropy:                                                 
              _target_: torchmetrics.MeanMetric                                 
              sync_on_compute: false                                            
            State/prior_entropy:                                                
              _target_: torchmetrics.MeanMetric                                 
              sync_on_compute: false                                            
            Grads/world_model:                                                  
              _target_: torchmetrics.MeanMetric                                 
              sync_on_compute: false                                            
            Grads/actor:                                                        
              _target_: torchmetrics.MeanMetric                                 
              sync_on_compute: false                                            
            Grads/critic:                                                       
              _target_: torchmetrics.MeanMetric                                 
              sync_on_compute: false                                            
        logger:                                                                 
          _target_: lightning.fabric.loggers.TensorBoardLogger                  
          name: 2024-09-19_14-55-26_dreamer_v3_PongNoFrameskip-v4_42            
          root_dir: logs/runs/dreamer_v3/PongNoFrameskip-v4                     
          version: null                                                         
          default_hp_metric: true                                               
          prefix: ''                                                            
          sub_dir: null                                                         

belerico commented 2 months ago

Hi @chrisgao99! The problem you're seeing seems related to the frame_skip and action_repeat values, which differ from the defaults. Can you try leaving those at their default values?

chrisgao99 commented 2 months ago

Sure, I can definitely try that. But do you think frame skip and action repeat can make such a big difference? I thought they just control data collection in the real env and how actions are executed during testing. What would be a normal training time for Pong with dreamer_v3?

Besides, should frame skip and action repeat always be the same? If they should, why are there two parameters?

michele-milesi commented 2 months ago

Hi @chrisgao99, yes, action repeat can make a difference during training. Atari environments use the frame_skip parameter to specify action repetition. We decided to standardize the name of this parameter across environments and always use action_repeat: when you change its value, the frame_skip parameter is set automatically (unless it is redefined, as you did).

As for the training duration, dreamer_v3 is rather expensive to train, especially if you set the replay_ratio parameter to 1 (the replay ratio controls how many gradient steps are performed per environment step, so lowering it directly reduces training time). Since you set the total steps to 1600000, you can try reducing the replay ratio to 0.25 or 0.125.
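
For example (a sketch assuming the standard sheeprl launcher and Hydra overrides; adjust to your setup), a run that keeps the Atari defaults for action repetition (i.e. does not override env.action_repeat or env.frame_skip) and lowers the replay ratio could look like:

sheeprl exp=dreamer_v3 env=atari env.id=PongNoFrameskip-v4 algo.replay_ratio=0.25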

chrisgao99 commented 1 month ago

Thanks a lot for your advice. It's very helpful for my experiments with Atari.

I have a few more questions regarding the dreamer_v3 configs, and I just opened a new issue.