neel04 opened this issue 1 year ago
Hi Neel,
Thanks so much for looking into this and opening a PR. We're looking into the issue and your proposed solution, and I'll follow up again soon.
Thanks again!
Perhaps a simpler way to resolve this would be to weight the loss by the total number of iterations, e.g. with a decaying schedule:
```python
steps = n + k
...
# coefficients are arbitrary and can be adjusted; include max_iters since n + k can reach it
weight_dict = {i: 2 * (1 - (i / max_iters)) for i in range(max_iters + 1)}
loss_progressive = criterion(outputs, targets)
# multiplying by a Python float is enough; no need to build a tensor on `device`
loss_progressive = loss_progressive * weight_dict[steps]
```
Here you can just return `n + k` instead of just `k` from the function called during training: https://github.com/aks2203/deep-thinking/blob/152fb1182fb5dc11275859407ceec46102c20e42/deepthinking/utils/training.py#L56

It's already passed in the main training loop but discarded, so that would be a simpler change than what I suggested. A minimal sketch of the idea follows.
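For illustration, here is a minimal sketch of that change, assuming a helper shaped roughly like the linked `get_output_for_prog_loss`; the body is simplified and the `net(...)` call and its keyword arguments are placeholders, not the repo's exact API:

```python
from random import randrange


def get_output_for_prog_loss(inputs, max_iters, net):
    """Placeholder version of the linked helper, modified to return n + k."""
    n = randrange(0, max_iters)               # iterations used only to build features
    k = randrange(1, max_iters - n + 1)       # iterations the progressive loss sees
    outputs = net(inputs, iters_to_do=n + k)  # stand-in forward call
    return outputs, n + k                     # return the total instead of just k


# In the training loop, the returned total can index the weighting scheme above:
#   outputs, steps = get_output_for_prog_loss(inputs, max_iters, net)
#   loss_progressive = criterion(outputs, targets) * weight_dict[steps]
```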
Hi Neel,
I like the `get_skewed_n_and_k()` function above. I understand completely that the sampling we had wasn't uniform and that this version is. But one thing I'm still wondering about is how this affects generalization. Have you compared the performance of the two methods? It would be great to point to some results -- even just a few trials in one problem domain -- to see the impact.
Is that something you already have? I can do this comparison on my end if not.
What do you think?
Hey Avi,
I've made some heavy modifications to the OG architecture and I'm testing on a totally different suite of algorithmic tasks. I haven't done any significant ablations, but for simpler tasks it appears that this re-weighting is not needed at all - OOD generalisation can be achieved through entirely different methods.
For more complex tasks, I do see that it helps extrapolation by a few percentage points (roughly 5-10 points of whole-string test accuracy), but it's complicated by the fact that the best extrapolation accuracy here is about 65%. So very likely there are several other missing priors needed to extrapolate on my task; a skewed `n + k` might help somewhat, but ultimately not be as effective for full OOD generalisation.
But even then, I can't think of any scenario where such an (extreme) imbalance would be helpful. All results, insignificant as they may be, show similar performance or slight improvements. I suspect this issue may matter more at higher iteration counts (>30-50) and on harder tasks, preventing the model from effectively distributing work across individual iterations... but that has to be properly tested, I suppose.
LMK if you want me to run some changes on my side and see how they transfer across to other tasks
Cheers, Neel
https://github.com/aks2203/deep-thinking/blob/6b8aa99bfc3456541cc842d67f93a3fa41aa9055/deepthinking/utils/training.py#L42-L47
Subtle bug, but you sample two random variables, `n` and `k`. The distribution of the sum `n + k` will NOT be uniform.

For example, with `max_iters=10`, the distribution of the sum is heavily skewed, making the model worse at generalizing iteration-wise, especially for harder tasks which require learning harder dynamics.

One solution might be to skew the sampling of `k` s.t. the sum approaches a uniform distribution; a sketch of the idea is below.
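A minimal sketch of the idea: draw the sum first so that `n + k` is exactly uniform. The naive sampler below paraphrases the linked lines (its exact ranges are an assumption), and the real `get_skewed_n_and_k()` additionally draws a random `skew` each call, as described next; that detail is omitted here.

```python
from collections import Counter
from random import randrange


def get_n_and_k_naive(max_iters):
    """Current sampling (paraphrased from training.py#L42-L47): n first, then k."""
    n = randrange(0, max_iters)
    k = randrange(1, max_iters - n + 1)
    return n, k


def get_skewed_n_and_k(max_iters):
    """Sketch: draw the sum first so that n + k is uniform over 1..max_iters."""
    total = randrange(1, max_iters + 1)  # n + k, uniform by construction
    n = randrange(0, total)              # 0 <= n < total
    return n, total - n                  # k = total - n >= 1


if __name__ == "__main__":
    max_iters, trials = 10, 100_000
    for sampler in (get_n_and_k_naive, get_skewed_n_and_k):
        sums = Counter(sum(sampler(max_iters)) for _ in range(trials))
        # Empirical counts of n + k = 1..max_iters: heavily non-uniform for the
        # naive sampler, roughly flat for the sum-first sampler.
        print(sampler.__name__, [sums[s] for s in range(1, max_iters + 1)])
```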
Additionally, because the `skew` is also randomly sampled, the distribution shifts a little bit each time, so as to provide better coverage for all iterations.

LMK if you want a PR @aks2203!
Have a great day! Neel