slerman12 closed this issue 2 years ago.
You'd have to keep a non-augmented copy of the state around and run it through the encoder; right now, we only feed transformed images into the encoder. You'd just need to modify the first part of the `forward` method in the model, around here. You'd then need to calculate the loss somewhere, either in `forward` or in the `loss` function in `algos`, return it, and add it to the overall loss used by the optimizer in `optimize_agent`.
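A minimal, framework-free sketch of that change, assuming a toy linear `encode` and a brightness-jitter `augment` as hypothetical stand-ins for the model's conv encoder and image augmentation, and mean squared error as just one possible choice of F(s, s_aug). The shape of the idea is what matters: keep the observation before augmenting, encode both copies, and return the extra term so it can be summed into the total loss. In the real PyTorch model you would also need to decide whether the non-augmented branch should receive gradients (for example, by computing it under `torch.no_grad()`).

```python
import random
from typing import List, Sequence, Tuple

Vector = List[float]
Matrix = Sequence[Sequence[float]]


def augment(obs: Sequence[float], rng: random.Random) -> Vector:
    """Hypothetical stand-in for the image augmentation: random brightness jitter."""
    scale = 1.0 + 0.1 * rng.uniform(-1.0, 1.0)
    return [x * scale for x in obs]


def encode(obs: Sequence[float], weights: Matrix) -> Vector:
    """Hypothetical stand-in for the conv encoder: a fixed linear map."""
    return [sum(w * x for w, x in zip(row, obs)) for row in weights]


def mse(a: Sequence[float], b: Sequence[float]) -> float:
    """One possible F(s, s_aug): mean squared error between two latent vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


def forward(obs: Sequence[float], weights: Matrix, rng: random.Random) -> Tuple[Vector, float]:
    """Encode both the kept non-augmented copy and the augmented observation."""
    raw_latent = encode(obs, weights)                # s: the non-augmented copy kept around
    aug_latent = encode(augment(obs, rng), weights)  # s_aug: what the encoder sees today
    extra_loss = mse(raw_latent, aug_latent)         # F(s, s_aug), returned alongside other losses
    return aug_latent, extra_loss


def total_loss(rl_loss: float, spr_loss: float, extra_loss: float,
               extra_weight: float = 1.0) -> float:
    """Mirrors what optimize_agent is described as doing: adding the losses together.
    extra_weight is a hypothetical knob, not an existing repo option."""
    return rl_loss + spr_loss + extra_weight * extra_loss


if __name__ == "__main__":
    rng = random.Random(0)
    weights = [[0.5, -0.2, 0.1], [0.3, 0.3, 0.3]]
    latent, extra = forward([1.0, 2.0, 3.0], weights, rng)
    print(latent, extra, total_loss(rl_loss=1.2, spr_loss=0.4, extra_loss=extra))
```

Returning the extra term from `forward` (rather than recomputing anything later) keeps all access to the encoder inside the model, which matches the advice further down that `optimize_agent` only sums losses.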
Would it be possible, then, to disable augmentation until the `optimize_agent` function and then augment there?
One more thing I didn't mention: I also need to get the Q-values for the state into the `optimize_agent` function. Is it possible to retrieve those independently of the loss?
Basically, I need (1) the original and augmented state and (2) a way to compute Q-values, both in the `optimize_agent` function.
Alternatively, would it be possible to do it all in `forward`? It looks like `forward` includes a term called `spr_loss`. Could I just add my loss term to that term, as in
spr_loss = spr_loss + F(s, s_aug)
It also looks like `head_forward` does the Q-value prediction.
`optimize_agent` really isn't designed for that sort of thing; it's just a wrapper that adds together losses, calls the optimizer, and does some logging. You won't have easy access to important parts of the model from there, due to namespaces. As you said, `forward` is probably the easiest place to do it.
You can get the raw Q-values by calling `from_categorical(log_pred_ps.exp(), 10, False)` in `loss` or `forward`. From there, you can use them in your loss calculation, or return them back up to `optimize_agent` if you just want to log them.
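For intuition, turning a categorical return distribution (as in C51/Rainbow-style heads) into a Q-value is an expectation over a fixed grid of atoms. A plain-Python sketch, assuming (not verified against the repo) that the `10` in `from_categorical(log_pred_ps.exp(), 10, False)` is the limit of an evenly spaced support on [-10, 10], and that `False` means the input is already probabilities rather than logits. `expected_q` and `q_values` are hypothetical names, not the repo's functions.

```python
import math
from typing import List, Sequence


def softmax(logits: Sequence[float]) -> List[float]:
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def support(limit: float, num_atoms: int) -> List[float]:
    """Evenly spaced atom values from -limit to +limit inclusive (needs num_atoms >= 2)."""
    step = 2.0 * limit / (num_atoms - 1)
    return [-limit + i * step for i in range(num_atoms)]


def expected_q(dist: Sequence[float], limit: float = 10.0, logits: bool = False) -> float:
    """Q = sum_i p_i * z_i: the mean of a categorical return distribution."""
    probs = softmax(dist) if logits else list(dist)
    atoms = support(limit, len(probs))
    return sum(p * z for p, z in zip(probs, atoms))


def q_values(log_probs_per_action: Sequence[Sequence[float]],
             limit: float = 10.0) -> List[float]:
    """One Q-value per action from log-probabilities, analogous to
    from_categorical(log_pred_ps.exp(), 10, False) under the assumptions above."""
    return [expected_q([math.exp(lp) for lp in row], limit, logits=False)
            for row in log_probs_per_action]


if __name__ == "__main__":
    log_ps = [[math.log(0.5), math.log(0.5)], [math.log(0.25), math.log(0.75)]]
    print(q_values(log_ps))
```

Since `log_pred_ps.exp()` already yields probabilities, passing `logits=False` avoids applying a second softmax, which would distort the distribution.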
Thank you so much! Yes, I'm running it now and it seems to be working; I'll report back tomorrow if not. Hopefully all goes smoothly. Weights & Biases is a really cool way to track results, but I've never used it before. Not to be too much of a bother, but would you happen to have some tips on how to use it to run a bunch of environments and random seeds and compile those results quickly? I'm not sure W&B takes care of that sort of thing, but right now the logs are a bit disorganized since each run is assigned a random ID.
Also, roughly what is the expected run time?
You can use WandB's grouping feature over environments to extract most of the information you'd need, but for the most part it's easiest to download the results as a CSV and analyze them offline with Pandas (that way you can also make your own figures).
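One way to keep runs organized, sketched under assumptions: pass `group` and `name` to `wandb.init` so runs are grouped by environment and labeled by seed instead of being identified only by a random ID, then export the results and aggregate them offline. The helper, the project name `spr-experiments`, and the column names `env`, `seed`, and `final_return` are all hypothetical, and the CSV rows in the example are made-up numbers for illustration only. The aggregation uses just the standard library; with Pandas it would be roughly `df.groupby("env")["final_return"].agg(["mean", "std"])`.

```python
import csv
import io
import statistics
from collections import defaultdict
from typing import Dict, List, Tuple


def wandb_init_kwargs(env: str, seed: int, project: str = "spr-experiments") -> dict:
    """Keyword arguments for wandb.init(**kwargs); the naming scheme is hypothetical."""
    return {
        "project": project,
        "group": env,                 # lets the W&B UI group runs by environment
        "name": f"{env}-seed{seed}",  # readable run name instead of only a random ID
        "config": {"env": env, "seed": seed},
    }


def summarize(csv_text: str,
              metric: str = "final_return") -> Dict[str, Tuple[float, float, int]]:
    """Mean, sample std, and run count of `metric` per environment from an exported CSV."""
    by_env: Dict[str, List[float]] = defaultdict(list)
    for row in csv.DictReader(io.StringIO(csv_text)):
        by_env[row["env"]].append(float(row[metric]))
    summary = {}
    for env, values in by_env.items():
        std = statistics.stdev(values) if len(values) > 1 else 0.0
        summary[env] = (statistics.mean(values), std, len(values))
    return summary


if __name__ == "__main__":
    exported = "env,seed,final_return\nenv_a,0,18.0\nenv_a,1,20.0\nenv_b,0,30.0\n"
    for env, (mean, std, n) in sorted(summarize(exported).items()):
        print(f"{env}: mean={mean:.1f} std={std:.2f} n={n}")
```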
As for runtime, it depends on your hardware and what you added, but base SPR takes about 4 hours per run on the GPUs we have (V100s).
Any chance it can run on CPUs? We only have 4 V100s.
You're welcome to try, but my guess is that it would take ~24 hours on CPU. I don't think our GPU utilization is even close to 100%, though, so maybe you could squeeze multiple runs onto the same card.
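If a single run leaves the GPU underused, one way to pack several runs onto each card is to give every (environment, seed) job a slot and pin it to a device with `CUDA_VISIBLE_DEVICES`. This is a sketch under assumptions: the `python -m scripts.run` command and its `--game`/`--seed` flags are hypothetical placeholders for however the training script is actually launched, and `runs_per_gpu=2` is a guess to tune while watching memory and utilization (for example, with `nvidia-smi`).

```python
import itertools
from typing import Dict, List, Sequence, Tuple

Job = Tuple[Dict[str, str], List[str]]  # (extra environment variables, command)


def plan_jobs(envs: Sequence[str], seeds: Sequence[int], num_gpus: int,
              runs_per_gpu: int = 2) -> List[List[Job]]:
    """Split (env, seed) jobs into waves; each wave fills num_gpus * runs_per_gpu slots."""
    slots = num_gpus * runs_per_gpu
    combos = list(itertools.product(envs, seeds))
    waves = []
    for start in range(0, len(combos), slots):
        wave = []
        for offset, (env, seed) in enumerate(combos[start:start + slots]):
            gpu = offset % num_gpus  # round-robin: each card gets at most runs_per_gpu jobs
            env_vars = {"CUDA_VISIBLE_DEVICES": str(gpu)}
            # Hypothetical CLI; replace with the repo's real launch command and flags.
            cmd = ["python", "-m", "scripts.run", "--game", env, "--seed", str(seed)]
            wave.append((env_vars, cmd))
        waves.append(wave)
    return waves


if __name__ == "__main__":
    # Dry run: print the plan. To launch a wave, one option is
    # subprocess.Popen(cmd, env={**os.environ, **env_vars}) per job, then wait() on all.
    for i, wave in enumerate(plan_jobs(["env_a", "env_b"], range(5), num_gpus=4)):
        for env_vars, cmd in wave:
            print(i, env_vars["CUDA_VISIBLE_DEVICES"], " ".join(cmd))
```

Running jobs in waves keeps the number of concurrent processes per card bounded, at the cost of idle slots when a wave's runs finish at different times.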
Any chance you'd want to run the experiment on your GPUs and be a co-author on the ICLR 2022 submission? Haha. There's so little time, and one of the NeurIPS reviewers strongly advised that we implement our technique with SPR; otherwise, the reviews were pretty good. It's unorthodox, but it would save us a lot of pain.
I might be able to help you out; I'm pretty sure it wouldn't violate school policy as long as I was on the paper, but it depends on how much you need to run. It might be possible to do one full set of runs before the deadline, but this isn't a great time to be looking for compute at Mila either.
Do you want to email me with more details on what you're doing?
I sent it to MaxASchwarzer@gmail.com. Should I send it to schwarzm@mila.quebec instead?
I'm confused: I use one A100 and 16 CPUs on a cluster, yet one run takes 9.5 hours, more than twice the 4 hours you mentioned.
Can you give me some hints about what might explain the difference?
Sorry, this is urgent. I am trying to modify the overall model loss with a term that depends on both the original state and the augmented state. I'm looking at the code and am pretty confused. Is there a way I can simply add a loss term that depends on those two things? For example:
loss = loss + F(s, s_aug)