dotchen / WorldOnRails

(ICCV 2021, Oral) RL and distillation in CARLA using a factorized world model
https://dotchen.github.io/world_on_rails/
MIT License

Requesting some more training details #17

Closed: aaronh65 closed this issue 3 years ago

aaronh65 commented 3 years ago

Hey again, I just wanted to ask for a bit more detail on training. I'm trying to reproduce your results (computing benchmark metrics on the Leaderboard validation routes), and the pretrained weights seem to perform better than the weights I train on my own setup. I have two questions about how main_model_10.th was trained:

  1. How large was the main_dataset used to train that model?
  2. How long did it take to train the model? I'm assuming it was 10 epochs total, which would have taken multiple days on the 100 GB main dataset I was using.

Also, thank you for the reply on the other issue; I'll need to mull over your answers and read through the code some more to understand what's going on. I'll reply to that issue once I've thought about it some more!

dotchen commented 3 years ago

How large was the main_dataset used to train that model?

The main_dataset we used to train our Leaderboard model is the one we released in DATASET.md. It consists of roughly one million frames, as described in Section 5 of the paper.

How long did it take to train the model? I'm assuming it was 10 epochs total, which would have taken multiple days on the 100 GB main dataset I was using

10 epochs. It took me 4 days to train on two Titan Xp GPUs with the default batch size. 100GB sounds a bit small: the dataset I used (RAILS_1M) is 3.4TB in the lmdb format before compression.
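As a rough back-of-envelope check of these numbers, here is a sketch that assumes decimal units (1 TB = 10^12 bytes, 1 GB = 10^9 bytes) and assumes a 100 GB dataset stores frames in the same uncompressed lmdb format with a similar per-frame size. Under those assumptions the figures imply about 3.4 MB per frame, so 100 GB would hold only about 29,000 frames, roughly 3% of RAILS_1M. The variable names are illustrative, not from the repo.

```python
# Back-of-envelope arithmetic for the numbers quoted in this thread.
# Assumptions: decimal units (1 TB = 1e12 B, 1 GB = 1e9 B) and a similar
# per-frame size for any other dataset stored in the same lmdb format.
RAILS_1M_FRAMES = 1_000_000        # "roughly one million frames"
RAILS_1M_BYTES = 3.4e12            # 3.4TB in lmdb format before compression
EPOCHS = 10
TRAIN_SECONDS = 4 * 24 * 3600      # 4 days on two Titan Xp GPUs

bytes_per_frame = RAILS_1M_BYTES / RAILS_1M_FRAMES             # 3.4e6 B, about 3.4 MB
frames_in_100gb = 100e9 / bytes_per_frame                      # about 29,400 frames
fraction_of_rails_1m = frames_in_100gb / RAILS_1M_FRAMES       # about 0.029
frames_per_second = RAILS_1M_FRAMES * EPOCHS / TRAIN_SECONDS   # about 28.9

print(f"{bytes_per_frame / 1e6:.1f} MB per frame")
print(f"100 GB holds about {frames_in_100gb:,.0f} frames "
      f"({fraction_of_rails_1m:.1%} of RAILS_1M)")
print(f"average training throughput: {frames_per_second:.1f} frames/s over the run")
```

The throughput figure averages over the whole 4-day run, so it is only an order-of-magnitude guide for estimating training time on other hardware.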

aaronh65 commented 3 years ago

Thanks for the reply; I had forgotten about that part of the paper. You mention that 1M frames is approximately 14 hours of driving, but later you mention that data collection occurs at 4 Hz. Wouldn't that imply 1M frames / 4 frames per second / 3600 seconds per hour ≈ 69 hours? Or is the data collected at 20 Hz while the policy runs at 4 Hz?

Also, separately: I wanted to modify the get_reward logic so that intersections have no-speed (stop) zones when the traffic light is yellow as well as red. Do you have a recommendation for how I can do that?

dotchen commented 3 years ago

Ah yes, good catch: it should be 69 hours. I used the wrong FPS to compute it (it should be 4 because of the frame skip, not 20, as you mentioned). Thank you for letting me know; we will update it on arXiv ASAP.
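The corrected figure can be reproduced with a few lines. This sketch assumes, as discussed above, that the simulator ticks at 20 Hz and that frames are effectively saved at 4 Hz because of the frame skip; the constant names are hypothetical, not taken from the repo.

```python
# Frames-to-hours conversion behind the 14 h vs 69 h discussion.
# Assumed: the simulator ticks at 20 Hz and a frame skip yields 4 saved frames/s.
SIM_HZ = 20
SAVED_HZ = 4
FRAMES = 1_000_000

def hours_of_driving(frames: int, frames_per_second: float) -> float:
    """Hours of simulated driving represented by `frames` saved frames."""
    return frames / frames_per_second / 3600

frame_skip = SIM_HZ // SAVED_HZ                      # 5 simulator ticks per saved frame
at_sim_rate = hours_of_driving(FRAMES, SIM_HZ)       # about 13.9 h (the "14 hours" figure)
at_saved_rate = hours_of_driving(FRAMES, SAVED_HZ)   # about 69.4 h (the corrected figure)

print(f"frame skip {frame_skip}: {at_sim_rate:.1f} h at {SIM_HZ} Hz, "
      f"{at_saved_rate:.1f} h at {SAVED_HZ} Hz")
```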

Also, separately: I wanted to modify the get_reward logic so that intersections have no-speed (stop) zones when the traffic light is yellow as well as red. Do you have a recommendation for how I can do that?

The easiest way is to comment out line 144 in rails/bellman.py. If you also do not want the agent to brake, remove the (red>0) at line 139 as well.

aaronh65 commented 3 years ago

Thanks for the tip! So this should encourage the agent to stop when traffic lights are yellow, right?

Another quick question: say I've already run data_phase2 on a data_dir, but then changed some of the labeling logic. If I run data_phase2 again, will the new labels overwrite the old labels?

dotchen commented 3 years ago

So this should encourage the agent to stop when traffic lights are yellow, right?

Sorry, I misread your question earlier. If you want the agent to also stop at yellow lights, you don't need to change anything; the default already handles that.

will the new labels overwrite the old labels?

Yup, this will be automatically handled by lmdb.

aaronh65 commented 3 years ago

Great, thanks! Another question: for training the image policy, your paper sets up the loss to maximize the policy's expected return plus an entropy regularization term. But in the code, act_loss is defined as the KL divergence between the policy's action probability distribution and the softmaxed distribution of Q-values for the training state.

Why is the loss formulated this way in the code? Is it equivalent to the loss function written in the paper?

dotchen commented 3 years ago

No prob!

Regarding Equation 3 in the paper: the two are mathematically equivalent, so we just wrote down the simplified version.
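A quick numerical check of that equivalence, as a sketch under stated assumptions: a discrete action set, α as the entropy weight (notation assumed here), and the KL taken as KL(π || softmax(Q/α)). The Q-values and policies are random stand-ins, not anything from the repo. For every π, the entropy-regularized return equals α·log Z − α·KL(π || softmax(Q/α)), where Z = Σ_a exp(Q(a)/α) does not depend on π, so maximizing one is the same as minimizing the other.

```python
import math
import random
from typing import List

def softmax(q: List[float], alpha: float) -> List[float]:
    """Boltzmann distribution softmax(Q / alpha), computed stably."""
    m = max(q)
    exps = [math.exp((qa - m) / alpha) for qa in q]
    z = sum(exps)
    return [e / z for e in exps]

def log_partition(q: List[float], alpha: float) -> float:
    """log Z = log sum_a exp(Q(a) / alpha), computed stably."""
    m = max(q)
    return m / alpha + math.log(sum(math.exp((qa - m) / alpha) for qa in q))

def entropy(p: List[float]) -> float:
    return -sum(pa * math.log(pa) for pa in p if pa > 0)

def kl(p: List[float], r: List[float]) -> float:
    """KL(p || r); r must be strictly positive wherever p is."""
    return sum(pa * math.log(pa / ra) for pa, ra in zip(p, r) if pa > 0)

def objective(pi: List[float], q: List[float], alpha: float) -> float:
    """Expected Q under pi plus an alpha-weighted entropy bonus."""
    return sum(pa * qa for pa, qa in zip(pi, q)) + alpha * entropy(pi)

def max_gap(num_actions: int = 9, alpha: float = 0.5, trials: int = 100, seed: int = 0) -> float:
    """Largest |J(pi) - (alpha*log Z - alpha*KL(pi || softmax(Q/alpha)))| over random pi."""
    rng = random.Random(seed)
    q = [rng.uniform(-2.0, 2.0) for _ in range(num_actions)]
    target = softmax(q, alpha)
    log_z = log_partition(q, alpha)
    gap = 0.0
    for _ in range(trials):
        raw = [rng.random() + 1e-3 for _ in range(num_actions)]
        total = sum(raw)
        pi = [x / total for x in raw]
        identity_rhs = alpha * log_z - alpha * kl(pi, target)
        gap = max(gap, abs(objective(pi, q, alpha) - identity_rhs))
    return gap

print(f"max |J - identity| over random policies: {max_gap():.2e}")
```

If the code's KL runs the other way, KL(softmax(Q/α) || π), the two losses are no longer identical for every π, but both are minimized by the same policy π = softmax(Q/α) whenever the network can represent it.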

dotchen commented 3 years ago

Here is a small derivation:

[Image: derivation]
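A standard form of this derivation, written for a discrete action set with α as the entropy weight and π* = softmax(Q/α) (notation assumed here, possibly different from Equation 3 in the paper):

```latex
\begin{aligned}
J(\pi) &= \mathbb{E}_{a\sim\pi}\big[Q(s,a)\big] + \alpha\,\mathcal{H}\big(\pi(\cdot\mid s)\big)
        = \sum_a \pi(a\mid s)\big[Q(s,a) - \alpha\log\pi(a\mid s)\big],\\
\pi^*(a\mid s) &= \frac{\exp\big(Q(s,a)/\alpha\big)}{Z(s)},\qquad
        Z(s) = \sum_{a'}\exp\big(Q(s,a')/\alpha\big),\\
Q(s,a) &= \alpha\log\pi^*(a\mid s) + \alpha\log Z(s),\\
J(\pi) &= \alpha\sum_a \pi(a\mid s)\log\frac{\pi^*(a\mid s)}{\pi(a\mid s)} + \alpha\log Z(s)
        = -\alpha\,\mathrm{KL}\big(\pi(\cdot\mid s)\,\|\,\pi^*(\cdot\mid s)\big) + \alpha\log Z(s).
\end{aligned}
```

Since α log Z(s) does not depend on π, maximizing J(π) is the same as minimizing KL(π || π*), and the optimum is π = π*.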

aaronh65 commented 3 years ago

Sounds good, thanks for the info!

While running data_phase1 in parallel, I often get the following error, always when turning the traffic_manager's synchronous mode on or off:

File "/home/aaronhua/WorldOnRails/leaderboard/leaderboard/leaderboard_evaluator.py", line 162, in _cleanup
    self.traffic_manager.set_synchronous_mode(False)
RuntimeError: rpc::timeout: Timeout of 2000ms while calling RPC function 'set_synchronous_mode'
(pid=55313)   File "/home/aaronhua/WorldOnRails/leaderboard/leaderboard/leaderboard_evaluator.py", line 238, in _load_and_wait_for_world
(pid=55313)     self.traffic_manager.set_synchronous_mode(True)
(pid=55313) RuntimeError: rpc::timeout: Timeout of 2000ms while calling RPC function 'set_synchronous_mode'

My guess is that when running in parallel, if one worker happens to be writing data to disk at the end of a run while another worker is trying to reset the simulator, the entire system slows down and triggers the 2-second traffic_manager timeout. I modified scripts/launch_carla to default to a server timeout of about 20 minutes, but I still get the 2s timeout above. I've looked around to see whether the traffic_manager-specific timeout can be set, but I couldn't find a comprehensive list of CarlaUE4.sh command-line options.

The closest I can find is here, which implies that the C++ implementation of the TrafficManager has a method for setting the set_synchronous_mode timeout. The Python bindings don't appear to expose this method, though. Do you have any idea how I could resolve this?

dotchen commented 3 years ago

This error usually appears after a preceding error; please check whether any other error is thrown before it.
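The check suggested above can be partly automated. This is a hypothetical helper, not part of WorldOnRails or CARLA: it scans a captured log and returns the exception summary lines (such as `SomeError: message`, optionally prefixed with the `(pid=…)` tag seen in the log above) that appear before the first `rpc::timeout` line. It is only a heuristic based on the usual Python traceback format.

```python
import re
from typing import Iterable, List

# Heuristic for a Python exception summary line, e.g. "KeyError: 'rgb'" or
# "(pid=55313) RuntimeError: ...", with the optional "(pid=N)" worker prefix.
EXC_LINE = re.compile(r"^\s*(?:\(pid=\d+\)\s*)?[A-Za-z_][\w.]*(?:Error|Exception)\b")

def errors_before_timeout(lines: Iterable[str], marker: str = "rpc::timeout") -> List[str]:
    """Return exception lines seen before the first line containing `marker`.

    If `marker` never appears, every exception line in the log is returned.
    """
    found: List[str] = []
    for line in lines:
        if marker in line:
            break
        if EXC_LINE.match(line):
            found.append(line.strip())
    return found

# Made-up sample log: an earlier KeyError followed by the timeout.
sample_log = """\
Traceback (most recent call last):
  File "agent.py", line 10, in run_step
(pid=1) KeyError: 'rgb'
(pid=1)   File "leaderboard_evaluator.py", line 238, in _load_and_wait_for_world
(pid=1) RuntimeError: rpc::timeout: Timeout of 2000ms while calling RPC function 'set_synchronous_mode'
"""

print(errors_before_timeout(sample_log.splitlines()))  # ["(pid=1) KeyError: 'rgb'"]
```

With a real log, pass an open file object (for example, a hypothetical `worker.log`) directly, since iterating over a file yields its lines.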

P.S. This issue (like others in this thread) is separate from the one in the title. Please consider opening a separate thread so it is easier for others to navigate. Thanks!