erfanzar / EasyDeL

Accelerate your training with this open-source library. Optimize performance with streamlined training and serving options in JAX. 🚀
https://easydel.readthedocs.io/en/latest/
Apache License 2.0

Add label smoothing, z_loss and ignore <=0 tokens in loss calculation #102

Closed · yhavinga closed this 4 months ago

yhavinga commented 4 months ago

The train step loss function makes use of compute_weighted_cross_entropy_and_accuracy(), an extended version of compute_weighted_cross_entropy() that also calculates accuracy. This function supports label smoothing and z_loss.

The eval step loss function still makes use of cross_entropy_loss_and_accuracy(). When calling it, the loss function also incorporates the (in)validity of label tokens (those <= 0), together with the attention masks.
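
For context, a weighted cross-entropy of this kind (in the style of the T5X helpers) typically builds smoothed soft targets and adds a penalty on the log-partition function. A minimal sketch, assuming a T5X-style signature; the function name and details are illustrative, not EasyDeL's exact API:

```python
import jax
import jax.numpy as jnp

def weighted_cross_entropy_sketch(logits, labels, weights,
                                  label_smoothing=0.0, z_loss=0.0):
    """Per-token cross-entropy with label smoothing and a z_loss penalty."""
    vocab_size = logits.shape[-1]
    confidence = 1.0 - label_smoothing
    low_confidence = label_smoothing / (vocab_size - 1)
    # Smoothed one-hot targets: `confidence` on the true class, the rest
    # spread uniformly over the remaining classes.
    soft_targets = jnp.where(jax.nn.one_hot(labels, vocab_size) > 0,
                             confidence, low_confidence)
    log_probs = jax.nn.log_softmax(logits)
    loss = -jnp.sum(soft_targets * log_probs, axis=-1)
    # z_loss penalizes a large log-partition function, which keeps the
    # logits' normalizer from drifting during training.
    log_z = jax.scipy.special.logsumexp(logits, axis=-1)
    loss = loss + z_loss * jnp.square(log_z)
    # `weights` zeros out padding / ignored label positions.
    return jnp.sum(loss * weights) / jnp.maximum(jnp.sum(weights), 1e-8)
```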

yhavinga commented 4 months ago

@erfanzar I just checked the output of the first model I trained on it and it produces garbage.. (even though the loss was quite nice). It could be anywhere: labels in the dataset, not attending to the eos token, shifted wrong.. I think it's best to revert this commit and I'll try to figure out where it's failing later (it will take a while though, as I won't have the opportunity to work on it for a while). Apologies for the noise!

erfanzar commented 4 months ago

@yhavinga that's fine, and thanks for reporting.

yhavinga commented 4 months ago

@erfanzar :-) Pretty sure I've found it: the attention mask was shifted right instead of left in this line:

             loss_weights = jnp.where(
-                (batch["attention_mask"][:, :-1] != 0) & (labels > 0), 1, 0
+                (batch["attention_mask"][:, 1:] != 0) & (labels > 0), 1, 0
             )
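
A quick sketch of why the left shift is the correct one (toy data, not the trainer's actual code): the labels are the inputs shifted left by one, so the mask has to be shifted the same way to describe the targets rather than the inputs:

```python
import jax.numpy as jnp

# Toy batch padded with eos (id 2): the mask must describe the *labels*.
input_ids = jnp.array([[11, 12, 13, 2, 2, 2]])
attention_mask = jnp.array([[1, 1, 1, 1, 0, 0]])

labels = input_ids[:, 1:]  # logits at position t predict token t+1

# Correct (left shift): mask aligned with the labels.
w_fixed = jnp.where((attention_mask[:, 1:] != 0) & (labels > 0), 1, 0)
# -> [[1, 1, 1, 0, 0]]: loss on predicting 12, 13, and eos only.

# Buggy (right shift): mask aligned with the *inputs*, off by one.
w_buggy = jnp.where((attention_mask[:, :-1] != 0) & (labels > 0), 1, 0)
# -> [[1, 1, 1, 1, 0]]: also trains on predicting the padding after eos.
```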

I'm currently running tests with the new loss function against a baseline trained with the 'old' loss function. The new loss function supports ignoring loss for token labels of -100, z_loss, and label smoothing. I've already tested the new loss function without any new options on, and with ignoring -100 label tokens: no bugs encountered yet.

erfanzar commented 4 months ago

@yhavinga seems like that's fine. If you passed all of the tests, you can make a pull request and I'll optimize it (I'll use jax.lax instead of some other APIs, and add some TPU Pallas calls). And if you want, we can update the beta branch and work on that, so we can simply train some models and test.

yhavinga commented 4 months ago

@erfanzar all tests have turned out OK, yay! I need to rebase next (my tests and fixed code were all based on the main branch from ~6 Feb). Is it your preference that I rebase on the beta branch instead of main?

erfanzar commented 4 months ago

Yes, because I'm about to add new Trainers and change how loss functions work, so it would be nice if you could rebase that on beta.

yhavinga commented 4 months ago

@erfanzar so I have two branches with the fixed attention_mask shifting

1) https://github.com/yhavinga/EasyDeL/tree/ignore_token_label_smooth_z_loss -- on the main branch from 7 Feb
2) https://github.com/yhavinga/EasyDeL/tree/beta_ignore_token_label_smooth_z_loss -- on the beta branch

The first one is all OK, everything working. With the second one I get errors if I run e.g. python_test/easy_causal_language_model_trainer_test.py -- it gives errors also without my patch.

Traceback (most recent call last):
  File "/home/yeb/Developer/yhavinga/zuchtax/EasyDeL/lib/python/EasyDel/modules/flax_modelling_utils.py", line 438, in block_wise_ffn
    inputs = rearrange(inputs, 'b (c n) d -> b c n d', c=chunk_size)
  File "/home/yeb/Developer/yhavinga/zuchtax/venv/lib/python3.10/site-packages/einops/einops.py", line 483, in rearrange
    return reduce(cast(Tensor, tensor), pattern, reduction='rearrange', **axes_lengths)
  File "/home/yeb/Developer/yhavinga/zuchtax/venv/lib/python3.10/site-packages/einops/einops.py", line 420, in reduce
    raise EinopsError(message + '\n {}'.format(e))
einops.EinopsError:  Error while processing rearrange-reduction pattern "b (c n) d -> b c n d".
 Input tensor shape: (1, 128, 128). Additional info: {'c': 1024}.
 Shape mismatch, can't divide axis of length 128 in chunks of 1024

maybe it is due to a mismatch in a package?

$ pip freeze | grep ax
distrax==0.1.5
flax==0.7.5
jax==0.4.23
jaxlib==0.4.23
optax==0.1.8
orbax-checkpoint==0.5.0
rlax==0.1.6

erfanzar commented 4 months ago

@yhavinga Hello, and thank you for contributing. If you think it's working correctly and it can train the model, you can make a pull request again for that. The error you're getting right now is from scan_mlp; you can disable it or use smaller chunks. You can disable it by giving use_scan_mlp=False to the model config, or change scan_mlp_chunk_size=64 (or any other value).
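
For reference, the failing rearrange in the traceback is the chunking step of the scanned MLP, and it only works when the chunk size divides the sequence length. A minimal sketch of the constraint, with toy shapes:

```python
import jax.numpy as jnp
from einops import rearrange

seq_len = 128
x = jnp.zeros((1, seq_len, 128))  # (batch, sequence, hidden)

# The chunk size must divide the sequence length:
ok = rearrange(x, 'b (c n) d -> b c n d', c=64)  # (1, 64, 2, 128) - works
# rearrange(x, 'b (c n) d -> b c n d', c=1024)   # EinopsError: can't
#                                                # divide 128 into 1024 chunks
```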

yhavinga commented 4 months ago

@erfanzar Just tested a bit on tpu-v4-8:

I'm still testing with python_test/easy_causal_language_model_trainer_test.py. When Qwen2Config is initialized, all is OK with use_scan_mlp=False, until super().__init__() is called with **kwargs (which does NOT include use_scan_mlp); execution then enters EasyDelPretrainedConfig, and since use_scan_mlp is not passed there, it is set to that class's default of use_scan_mlp=True.
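
A minimal sketch of the inheritance bug described above (class names are stand-ins, not EasyDeL's actual code): the subclass accepts use_scan_mlp but never forwards it, so the parent applies its own default of True:

```python
class BaseConfig:  # stands in for EasyDelPretrainedConfig
    def __init__(self, use_scan_mlp: bool = True, **kwargs):
        self.use_scan_mlp = use_scan_mlp

class ModelConfig(BaseConfig):  # stands in for Qwen2Config
    def __init__(self, use_scan_mlp: bool = True, **kwargs):
        # Bug: use_scan_mlp is captured here but never forwarded, so the
        # parent only sees **kwargs and applies its default of True.
        super().__init__(**kwargs)
        # Fix: super().__init__(use_scan_mlp=use_scan_mlp, **kwargs)

cfg = ModelConfig(use_scan_mlp=False)
print(cfg.use_scan_mlp)  # True -- the caller's False was silently dropped
```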

erfanzar commented 4 months ago

It's OK. If the loss function works for training, you can re-create the functions on FJFormer, then we can port them and test them on EasyDeL.

yhavinga commented 4 months ago

@erfanzar I'm sorry, but I don't understand. Can I reach you on Discord anywhere? E.g. JaxLLM https://discord.gg/T5c48fDT ?

erfanzar commented 3 months ago

@yhavinga Yes, here's my Discord ID: citifer.

And I said you could make the functional changes you need on the FJFormer side, then make a new pull request or re-open this pull request, and I can debug if there is any issue.

erfanzar commented 3 months ago

> @erfanzar I'm sorry, but I don't understand. Can I reach you on Discord anywhere? E.g. JaxLLM https://discord.gg/T5c48fDT ?

we can continue here I guess https://discord.gg/Eh8mxrmD