HomebrewNLP / Olmax

HomebrewNLP in JAX flavour for maintainable TPU-Training
BSD 2-Clause "Simplified" License

Causality Test #78

Open ClashLuke opened 1 year ago

ClashLuke commented 1 year ago

Currently, we have to manually verify that a modification doesn't accidentally leak information, which is error-prone. Especially in situations where only some tokens can see future tokens, a leak can be difficult to notice from the loss curves alone. That's why we need to introduce a test that ensures our model cannot see future tokens, as access to them would make predicting those tokens trivially easy.

ClashLuke commented 1 year ago

My current best approach would be to initialize a regular model (or, for unit tests, optionally a single layer), compute the forward pass, and backpropagate through the loss at one specific position rather than the mean. This way, we can inspect the input's gradients and check whether any future token has a gradient != 0. In a separate test, we could also check that at least one past token has a gradient != 0, to ensure the model looks at its input at all. A sketch of this is below.

The issue with this approach is that leaks can be difficult to isolate. We've already seen multiple cases where a leak occurred from a few tokens to other individual tokens, but not from all to all. That's why we would need context_size separate backward passes, which can get expensive.
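A minimal sketch of this test in JAX, with a stand-in `model` (a masked cumulative mean) in place of the real Olmax model; `CTX`, `DIM`, and the helper names are hypothetical and only illustrate the gradient check:

```python
import jax
import jax.numpy as jnp

CTX = 8  # hypothetical context size; kept small so the full loop stays cheap
DIM = 4  # hypothetical embedding dimension


def model(x: jnp.ndarray) -> jnp.ndarray:
    """Stand-in causal 'model': a masked cumulative mean over the sequence.

    Any (ctx, dim) -> (ctx, dim) function works here; a leaky variant (e.g.,
    broadcasting x.mean(0) to all positions) should make the test fail.
    """
    mask = jnp.tril(jnp.ones((CTX, CTX)))  # position i sees tokens <= i
    return (mask @ x) / mask.sum(1, keepdims=True)


def grad_at_position(x: jnp.ndarray, pos: int) -> jnp.ndarray:
    # Backpropagate through the output at one specific position rather than
    # the mean, returning gradients w.r.t. the input.
    return jax.grad(lambda inp: model(inp)[pos].sum())(x)


def test_causality():
    x = jax.random.normal(jax.random.PRNGKey(0), (CTX, DIM))
    for pos in range(CTX):  # context_size separate backward passes
        per_token = jnp.abs(grad_at_position(x, pos)).sum(-1)
        # No future token may influence position `pos`.
        assert jnp.all(per_token[pos + 1:] == 0), f"leak into position {pos}"
        # At least one past/current token must carry gradient, proving the
        # model actually reads its input.
        assert jnp.any(per_token[: pos + 1] != 0), f"position {pos} ignores input"


test_causality()
print("causality test passed")
```

If the per-position backward passes get too expensive, one option (an assumption, not something settled here) would be `jax.jacrev` or a `jax.vmap` over the position index to compute all of them in one call, trading memory for time, and then checking the resulting input-output Jacobian for lower-triangularity directly.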