PAIR-code / tiny-transformers

Apache License 2.0

Add decoder-style causal loss for predicting all tokens at once #17

Open iislucas opened 2 months ago

iislucas commented 2 months ago

Context:

  1. A unit test for transformer setup and training: https://github.com/PAIR-code/tiny-transformers/blob/main/animated-transformer/src/lib/trainer/basic_transformer_trainer.spec.ts

  2. Transformer implementation: https://github.com/PAIR-code/tiny-transformers/blob/main/animated-transformer/src/lib/transformer/transformer_gtensor.ts

  3. GTensor is a class that encapsulates named tensors. See these unit tests to get a sense of it: https://github.com/PAIR-code/tiny-transformers/blob/main/animated-transformer/src/lib/gtensor/gtensor.spec.ts

  4. The current loss function: https://github.com/PAIR-code/tiny-transformers/blob/main/animated-transformer/src/lib/transformer/transformer_gtensor.ts#L372

Goal: Implement the standard decoder-only transformer loss (e.g. GPT-2 style): under a causal mask, every position predicts its next token, so a single forward pass provides gradients from every token simultaneously.
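To make the goal concrete, here is a minimal sketch of the loss computation on plain arrays. It does not use GTensor or any existing function from this repo; `logSoftmax` and `causalNextTokenLoss` are hypothetical names, and the sketch assumes `logits[i]` was produced with a causal mask so that it depends only on tokens `0..i`. A real implementation would presumably express the same thing with GTensor operations so that gradients flow through it.

```typescript
// Hypothetical helper: numerically stable log-softmax over one row of logits.
function logSoftmax(row: number[]): number[] {
  const max = Math.max(...row);
  const sumExp = row.reduce((acc, x) => acc + Math.exp(x - max), 0);
  const logSumExp = max + Math.log(sumExp);
  return row.map((x) => x - logSumExp);
}

// Hypothetical sketch of a decoder-style loss over all positions at once.
// logits[i]: unnormalised scores over the vocabulary at position i
//            (assumed to be computed under a causal mask).
// tokenIds:  the input token ids; the target for position i is tokenIds[i + 1].
// Returns the mean cross-entropy over the n - 1 next-token predictions.
function causalNextTokenLoss(logits: number[][], tokenIds: number[]): number {
  if (logits.length !== tokenIds.length) {
    throw new Error('logits and tokenIds must have the same sequence length');
  }
  const nTargets = tokenIds.length - 1;
  if (nTargets < 1) {
    throw new Error('need at least two tokens to form a next-token target');
  }
  let total = 0;
  for (let i = 0; i < nTargets; i++) {
    const logProbs = logSoftmax(logits[i]);
    // Negative log-likelihood of the token that actually follows position i.
    total -= logProbs[tokenIds[i + 1]];
  }
  return total / nTargets;
}

// Usage: with uniform logits over a vocabulary of 4, the loss is ln(4).
const uniformLoss = causalNextTokenLoss(
  [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]],
  [1, 2, 3],
);
console.log(uniformLoss);
```

The last position has no following token in the sequence, so it contributes no term here; only the first n - 1 positions are scored.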