This implementation of GPT-2 is great because not only is it JAX-native, but it also uses KV-caching, a technique that speeds up autoregressive text generation by reusing the keys and values computed for earlier tokens. This pull request shows how KV-caching works with an example.
I changed the batch size from 4 to 1 for two reasons:
This KV-cache-enabled GPT-2 is a good fit for running in an iOS or Android app, and an input batch size of 1 lets users type one sentence at a time.
Batch size 1 also makes it easy to type in the prompt when debugging:
If the batch size were 4, we would need to type a four-row matrix as the first input.
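The KV-caching idea this PR demonstrates can be sketched outside JAX with plain NumPy. This is a minimal illustration, not the PR's actual code: the `KVCache` class and weight names below are hypothetical, and it uses a single toy attention head with batch size 1. The point is that appending each new token's key and value to a cache gives the same attention output as recomputing all keys and values from scratch at every step, while only projecting the newest token.

```python
import numpy as np

rng = np.random.default_rng(0)
d = 8  # toy head dimension
# Hypothetical projection weights for query, key, and value.
Wq, Wk, Wv = (rng.normal(size=(d, d)) for _ in range(3))

def attend(q, K, V):
    # Softmax attention of a single query over all cached keys/values.
    s = (K @ q) / np.sqrt(d)
    w = np.exp(s - s.max())
    w /= w.sum()
    return w @ V

def step_no_cache(xs):
    # Without caching: re-project K and V for every past token each step.
    K = xs @ Wk
    V = xs @ Wv
    return attend(xs[-1] @ Wq, K, V)

class KVCache:
    # With caching: keep K and V rows from earlier steps;
    # each new step projects only the newest token.
    def __init__(self):
        self.K = np.zeros((0, d))
        self.V = np.zeros((0, d))

    def step(self, x):
        self.K = np.vstack([self.K, x @ Wk])
        self.V = np.vstack([self.V, x @ Wv])
        return attend(x @ Wq, self.K, self.V)

# Check that both paths produce identical attention outputs.
xs = rng.normal(size=(5, d))
cache = KVCache()
for t in range(5):
    cached = cache.step(xs[t])
    full = step_no_cache(xs[: t + 1])
    assert np.allclose(cached, full)
```

With batch size 1, the cache holds one `(t, d)` matrix per projection, which keeps the shapes easy to read when stepping through a single typed prompt in a debugger.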