mila-iqia / spr

Code for "Data-Efficient Reinforcement Learning with Self-Predictive Representations"
https://arxiv.org/abs/2007.05929
MIT License
157 stars 32 forks

Quick question #26

Closed slerman12 closed 2 years ago

slerman12 commented 2 years ago

Sorry, this is urgent. I am trying to modify the overall model loss with a term that depends on both the original state and the augmented state. I'm looking at the code and am pretty confused. Is there a way I can simply add a loss term that depends on those two things? e.g.,

loss = loss + F(s, s_aug)

MaxASchwarzer commented 2 years ago

You'd have to keep a non-augmented copy of the state around and run it through the encoder -- right now, we only feed transformed images into the encoder. You'd just need to modify the first part of the forward method in the model, around here. You'd then need to calculate the loss somewhere, either in forward or in the loss function in algos, and return it and add it to the overall loss used by the optimizer in optimize_agent.
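To make that concrete, here's a minimal sketch of the kind of change being described: run both the raw and the augmented state through the encoder and compute an extra loss term from the two embeddings. All names here (Model, transform, encoder) are hypothetical stand-ins, not the repo's actual API:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Model(nn.Module):
    """Hypothetical skeleton; in SPR the encoder and augmentation
    live elsewhere, but the shape of the change is the same."""

    def __init__(self):
        super().__init__()
        # Stand-in encoder for 4-frame 84x84 Atari observations.
        self.encoder = nn.Sequential(nn.Flatten(), nn.Linear(4 * 84 * 84, 32))

    def transform(self, images):
        # Stand-in for the image augmentation (noise instead of shifts/crops).
        return images + 0.01 * torch.randn_like(images)

    def forward(self, images):
        aug = self.transform(images)      # augmented copy, as SPR already does
        z_aug = self.encoder(aug)         # the embedding the model normally uses
        z_raw = self.encoder(images)      # extra pass on the non-augmented state
        # Example F(s, s_aug); replace with whatever your loss term needs.
        extra_loss = F.mse_loss(z_aug, z_raw)
        return z_aug, extra_loss
```

The extra_loss returned here would then be added to the total loss that the optimizer step consumes.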

slerman12 commented 2 years ago

Would it be possible then to disable augmentation until the optimize_agent function, and then augment there?

One more thing I didn't mention: I need to get the Q values for the state into the optimize_agent function as well. Is it possible to retrieve those independently of the loss?

Basically I need (1) the original and augmented state, (2) a way to compute Q values, both in the optimize_agent function

slerman12 commented 2 years ago

Alternatively, would it be possible to do it all in forward? It looks like forward includes a term called spr_loss. Could I just add my loss term to that term, as in,

spr_loss = spr_loss + F(s, s_aug)

Also, it looks like head_forward does the Q value prediction.

MaxASchwarzer commented 2 years ago

optimize_agent really isn't designed for that sort of thing -- it's just a wrapper that adds together losses, calls the optimizer, and does some logging. You won't have easy access to important parts of the model from there, due to namespaces. As you said, forward is probably the easiest place to do it.

You can get the raw Q-values by calling from_categorical(log_pred_ps.exp(), 10, False) in loss or forward. From there you can use them in your loss calculation, or return them up to optimize_agent if you just want to log them.
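For reference, a from_categorical helper of this kind typically collapses a C51-style categorical value distribution to its expectation over evenly spaced support atoms. The sketch below is an assumption about what SPR's helper does; the support range and atom spacing in the real code may differ:

```python
import torch


def from_categorical(probs, limit=10, logits=False):
    """Collapse a categorical value distribution to scalar Q-values.

    Sketch under assumptions: atoms are evenly spaced in [-limit, limit]
    and the Q-value is the expectation of the distribution over them.
    """
    if logits:
        probs = torch.softmax(probs, dim=-1)
    n_atoms = probs.shape[-1]
    support = torch.linspace(-limit, limit, n_atoms, device=probs.device)
    # Expected value per (batch, action) entry.
    return (probs * support).sum(-1)
```

With logits=False, as in the call above, the input is already a probability distribution (hence the .exp() on the log-probabilities).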

slerman12 commented 2 years ago

Thank you so much! Yes, I'm running it now and it seems to be working; I'll report back tomorrow if not. Hopefully all goes smoothly. Weights & Biases is a really cool way to track results, but I've never used it before. Not to be too much of a bother, but would you happen to have some tips on using it to run a bunch of environments and random seeds and compile those results quickly? I'm not sure W&B handles that sort of thing, and right now the logs are a bit disorganized since each run is assigned a random ID.

slerman12 commented 2 years ago

Also, roughly what is the expected runtime?

MaxASchwarzer commented 2 years ago

You can use WandB's grouping feature over environments to extract most of the information you'd need, but it's easiest to download the results as a CSV and analyze them offline with Pandas, for the most part (also because that way you can make your own figures).
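As a sketch of that offline-analysis step, assuming a CSV export with env, seed, and score columns (hypothetical names; adjust to whatever your W&B export actually contains), grouping across seeds with pandas might look like:

```python
import pandas as pd

# Stand-in for pd.read_csv("wandb_export.csv") on a real export.
df = pd.DataFrame({
    "env": ["pong", "pong", "breakout", "breakout"],
    "seed": [0, 1, 0, 1],
    "score": [19.0, 20.0, 87.0, 93.0],
})

# Aggregate across seeds per environment.
summary = df.groupby("env")["score"].agg(["mean", "std", "count"])
print(summary)
```

From there the per-environment means and standard deviations feed directly into your own tables and figures.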

As for runtime, it depends on your hardware and on what you added, but base SPR takes roughly 4 hours on the GPUs we have (V100s).

slerman12 commented 2 years ago

Any chance it can run on CPUs? We only have 4 V100s.

MaxASchwarzer commented 2 years ago

You're welcome to try, but my guess is that it would take ~24 hours on CPU. I don't think our GPU utilization is anywhere near 100%, though, so you might be able to squeeze multiple runs onto the same card.

slerman12 commented 2 years ago

Any chance you'd want to run the experiments on your GPUs and be a co-author on the ICLR 2022 submission? haha. There's so little time, and one of the NeurIPS reviewers strongly advised that we implement our technique with SPR; otherwise the reviews were pretty good. Unorthodox, but it would save us a lot of pain.

MaxASchwarzer commented 2 years ago

I might be able to help you out -- I'm pretty sure it wouldn't violate school policy as long as I were on the paper -- but it depends on how much you need to run. It might be possible to do one full set of runs before the deadline, but this isn't a great time to be looking for compute at Mila either.

Do you want to email me with more details on what you're doing?

slerman12 commented 2 years ago

I sent it to MaxASchwarzer@gmail.com. Should I send it to schwarzm@mila.quebec instead?

yueyang130 commented 2 years ago

I was confused because I'm using one A100 and 16 CPUs on a cluster, but one run takes 9.5 hours, which is much longer than the ~4 hours you mentioned.

Can you give me some hints about what might explain the difference?