Closed yhavinga closed 4 months ago
@erfanzar I just checked the output of the first model I trained on it and it produces garbage (even though the loss was quite nice). The problem could be anywhere: labels in the dataset, not attending on the eos token, shifting the wrong way. I think it's best to revert this commit, and I'll try to figure out where it's failing later (it will take a while though; I won't have the opportunity to work on it for a while). Apologies for the noise!
@yhavinga that's fine, and thanks for reporting.
@erfanzar :-) Pretty sure I've found it: the attention mask was shifted right instead of left on this line:
```diff
 loss_weights = jnp.where(
-    (batch["attention_mask"][:, :-1] != 0) & (labels > 0), 1, 0
+    (batch["attention_mask"][:, 1:] != 0) & (labels > 0), 1, 0
 )
```
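To illustrate why the left shift matters, here is a minimal NumPy sketch with made-up token ids (`jnp.where` behaves the same way). The key case is a tokenizer that pads with the eos id, so `labels > 0` alone cannot mask the padding:

```python
import numpy as np

# Toy batch: padding uses the EOS id (2), which is > 0,
# so the labels > 0 check alone cannot mask the padding positions.
input_ids      = np.array([[5, 6, 2, 2, 2]])   # a, b, EOS, pad, pad
attention_mask = np.array([[1, 1, 1, 0, 0]])   # pads are masked out

# Next-token prediction: labels are the inputs shifted left by one.
labels = input_ids[:, 1:]                       # [[6, 2, 2, 2]]

# Buggy (mask shifted right): keeps a weight on the first pad label.
buggy = np.where((attention_mask[:, :-1] != 0) & (labels > 0), 1, 0)

# Fixed (mask shifted left): weights align with the labels, so only
# the real next tokens (b and EOS) contribute to the loss.
fixed = np.where((attention_mask[:, 1:] != 0) & (labels > 0), 1, 0)

print(buggy.tolist())  # [[1, 1, 1, 0]] -- trains on one pad token
print(fixed.tolist())  # [[1, 1, 0, 0]] -- trains on b and EOS only
```

With the fix, the eos label still gets loss (so the model learns to stop), while padding positions are excluded.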
I'm currently running tests with the new loss function against a baseline trained with the 'old' loss function. The new loss function supports ignoring loss for tokens labeled -100, z_loss, and label smoothing. I've already tested the new loss function with none of the new options enabled, and with -100 label tokens ignored: no bugs encountered yet.
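For reference, here is a self-contained NumPy sketch of the three options described above. This is not the actual EasyDeL/FJFormer code; the function name is made up, and the z_loss formulation follows the common T5X-style convention of penalizing the squared log partition:

```python
import numpy as np

def weighted_cross_entropy(logits, labels, smoothing=0.0, z_loss=0.0):
    """Cross-entropy over logits [B, T, V] and integer labels [B, T].

    Positions labeled -100 are ignored; optional label smoothing and
    a z_loss term (z_loss * log(Z)^2) regularize the softmax.
    """
    vocab = logits.shape[-1]
    valid = (labels != -100)                    # ignore-index mask
    safe_labels = np.where(valid, labels, 0)    # avoid bad indexing

    log_z = np.log(np.sum(np.exp(logits), axis=-1))  # log partition Z
    log_probs = logits - log_z[..., None]

    # Smoothed targets: (1 - s) on the true class, s spread uniformly.
    one_hot = np.eye(vocab)[safe_labels]
    targets = (1.0 - smoothing) * one_hot + smoothing / vocab

    per_tok = -np.sum(targets * log_probs, axis=-1) + z_loss * log_z**2
    return np.sum(per_tok * valid) / np.maximum(np.sum(valid), 1)

rng = np.random.default_rng(0)
logits = rng.normal(size=(1, 4, 8))
labels = np.array([[3, 1, -100, -100]])         # last two ignored
loss = weighted_cross_entropy(logits, labels, smoothing=0.1, z_loss=1e-4)
print(float(loss))
```

With `smoothing=0.0` and `z_loss=0.0` this reduces to plain masked cross-entropy averaged over the valid positions.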
@yhavinga ;\ seems like that's fine. If you've passed all of the tests you can make a pull request and I'll optimize it (I'll use jax.lax instead of some other APIs, and add some TPU Pallas calls). If you want, we can update the beta branch and work on that, so we can simply train some models and test.
@erfanzar all tests have turned out ok, yay! I need to rebase next (my tests and fixed code were all based on the main branch from ~6 Feb). Is it your preference that I rebase on the beta branch instead of main?
yes, because I'm about to add new Trainers and change how loss functions work, so it would be nice if you could rebase that on beta
@erfanzar so I have two branches with the fixed attention_mask shifting:

1. https://github.com/yhavinga/EasyDeL/tree/ignore_token_label_smooth_z_loss -- based on the main branch from 7 Feb
2. https://github.com/yhavinga/EasyDeL/tree/beta_ignore_token_label_smooth_z_loss -- based on the beta branch
The first one is all ok, everything working. With the second one I get errors when I run e.g. python_test/easy_causal_language_model_trainer_test.py; it gives errors even without my patch:
```
Traceback (most recent call last):
  File "/home/yeb/Developer/yhavinga/zuchtax/EasyDeL/lib/python/EasyDel/modules/flax_modelling_utils.py", line 438, in block_wise_ffn
    inputs = rearrange(inputs, 'b (c n) d -> b c n d', c=chunk_size)
  File "/home/yeb/Developer/yhavinga/zuchtax/venv/lib/python3.10/site-packages/einops/einops.py", line 483, in rearrange
    return reduce(cast(Tensor, tensor), pattern, reduction='rearrange', **axes_lengths)
  File "/home/yeb/Developer/yhavinga/zuchtax/venv/lib/python3.10/site-packages/einops/einops.py", line 420, in reduce
    raise EinopsError(message + '\n {}'.format(e))
einops.EinopsError: Error while processing rearrange-reduction pattern "b (c n) d -> b c n d".
 Input tensor shape: (1, 128, 128). Additional info: {'c': 1024}.
 Shape mismatch, can't divide axis of length 128 in chunks of 1024
```
Maybe it is due to a package version mismatch?

```
$ pip freeze | grep ax
distrax==0.1.5
flax==0.7.5
jax==0.4.23
jaxlib==0.4.23
optax==0.1.8
orbax-checkpoint==0.5.0
rlax==0.1.6
```
@yhavinga Hello, and thank you for contributing. If you think it's working correctly and can train the model, you can make a pull request again. The error you're getting right now is from scan_mlp; you can disable that or use smaller chunks.
You can disable it by passing use_scan_mlp=False to the model config, or change scan_mlp_chunk_size=64 (or any other value that divides the sequence length).
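The shape error comes from the chunked rearrange: the pattern `'b (c n) d -> b c n d'` can only split the sequence axis when its length is a multiple of the chunk setting. A quick NumPy reshape (equivalent to the einops pattern for this case; illustrative only) shows why 1024 fails on a length-128 sequence while 64 works:

```python
import numpy as np

b, seq_len, d = 1, 128, 128           # shapes from the traceback above
x = np.zeros((b, seq_len, d))

def chunk(inputs, chunk_size):
    """Mimic rearrange(inputs, 'b (c n) d -> b c n d', c=chunk_size)."""
    b, s, d = inputs.shape
    if s % chunk_size != 0:
        raise ValueError(
            f"can't divide axis of length {s} in chunks of {chunk_size}")
    return inputs.reshape(b, chunk_size, s // chunk_size, d)

try:
    chunk(x, 1024)                    # the failing default: 128 % 1024 != 0
except ValueError as e:
    print("error:", e)

print(chunk(x, 64).shape)             # (1, 64, 2, 128) -- a valid chunking
```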
@erfanzar Just tested a bit on tpu-v4-8:
I'm still testing with python_test/easy_causal_language_model_trainer_test.py. When Qwen2Config is initialized, everything is ok with use_scan_mlp=False, until super().__init__() is called with **kwargs, which does NOT include use_scan_mlp. Execution then enters EasyDelPretrainedConfig, where use_scan_mlp is not passed and so falls back to that class's default, use_scan_mlp=True.
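If I read that right, it's the classic base-class-default pitfall. A stripped-down sketch (the class names only mirror the report; this is not the actual EasyDeL code):

```python
class BaseConfig:
    """Stands in for EasyDelPretrainedConfig (hypothetical minimal version)."""
    def __init__(self, use_scan_mlp=True, **kwargs):
        self.use_scan_mlp = use_scan_mlp

class BuggyModelConfig(BaseConfig):
    """Stands in for Qwen2Config: the flag is accepted but not forwarded."""
    def __init__(self, use_scan_mlp=True, **kwargs):
        super().__init__(**kwargs)   # use_scan_mlp is NOT in kwargs here

class FixedModelConfig(BaseConfig):
    """Forwarding the flag explicitly preserves the caller's value."""
    def __init__(self, use_scan_mlp=True, **kwargs):
        super().__init__(use_scan_mlp=use_scan_mlp, **kwargs)

print(BuggyModelConfig(use_scan_mlp=False).use_scan_mlp)  # True -- lost
print(FixedModelConfig(use_scan_mlp=False).use_scan_mlp)  # False -- kept
```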
it's ok. If the loss function works for training, you can re-create the functions on FJFormer, then we can port them and test them on EasyDeL.
@erfanzar I'm sorry, but I don't understand. Can I reach you on a discord anywhere? E.g. JaxLLM https://discord.gg/T5c48fDT ?
@yhavinga Yes, here's my discord id: citifer.
and I said you could make the functional changes you need on the FJFormer side and make a new pull request, or re-open this pull request, and I can debug if there are any issues.
we can continue here I guess https://discord.gg/Eh8mxrmD
The train step loss function makes use of compute_weighted_cross_entropy_and_accuracy(), an extended version of compute_weighted_cross_entropy() that also calculates accuracy. This function supports label smoothing and z_loss.
The eval step loss function still makes use of cross_entropy_loss_and_accuracy(); when calling it, the loss function also incorporates the (in)validity of label tokens (labels <= 0), together with the attention masks.
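As a rough sketch of how accuracy can be computed over the same weights as the loss (NumPy, illustrative only; the real compute_weighted_cross_entropy_and_accuracy lives in FJFormer/EasyDeL):

```python
import numpy as np

def weighted_accuracy(logits, labels, attention_mask):
    """Token accuracy counted only on positions that carry loss weight:
    attended positions whose label is a valid (> 0) token id."""
    weights = (attention_mask != 0) & (labels > 0)
    correct = (np.argmax(logits, axis=-1) == labels) & weights
    return correct.sum() / np.maximum(weights.sum(), 1)

logits = np.array([[[0.1, 2.0, 0.3],     # argmax -> 1, label 1: correct
                    [0.9, 0.2, 0.1],     # argmax -> 0, label 2: wrong
                    [0.1, 0.2, 3.0]]])   # masked out (padding position)
labels = np.array([[1, 2, 0]])           # last label is padding (0)
mask   = np.array([[1, 1, 0]])
print(weighted_accuracy(logits, labels, mask))  # 0.5 -- 1 of 2 weighted
```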