iree-org / iree-jax


Change batch size to 1 and add an example #56

Closed wangkuiyi closed 1 year ago

wangkuiyi commented 1 year ago

This implementation of GPT-2 is great: not only is it JAX-native, it also uses KV-caching, a technique that speeds up autoregressive generation by reusing the keys and values of previously processed tokens instead of recomputing them for every new token. This pull request adds an example that shows how the KV-caching works.
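For readers new to the technique, here is a minimal, self-contained JAX sketch of the KV-caching idea. It is not the code in this repository; the names (`init_cache`, `decode_step`) and the head dimension are illustrative. The point is that the keys and values of past tokens are written into a fixed-size cache once, so each new token needs only a single attention pass over the cache rather than a full pass over the whole prefix.

    import jax
    import jax.numpy as jnp

    D = 16        # head dimension (illustrative)
    MAX_LEN = 8   # cache capacity; matches the 1x8 prompt shape used below

    def init_cache():
        # One cache slot per position; batch size 1 is implicit.
        return {"k": jnp.zeros((MAX_LEN, D)), "v": jnp.zeros((MAX_LEN, D))}

    def decode_step(cache, t, q, k, v):
        """Write (k, v) for position t, then attend q over positions 0..t."""
        cache = {"k": cache["k"].at[t].set(k), "v": cache["v"].at[t].set(v)}
        mask = jnp.arange(MAX_LEN) <= t            # hide slots not written yet
        scores = cache["k"] @ q / jnp.sqrt(D)
        weights = jax.nn.softmax(jnp.where(mask, scores, -jnp.inf))
        return cache, weights @ cache["v"]

    # Feed tokens one at a time; the cache carries the history.
    key = jax.random.PRNGKey(0)
    cache = init_cache()
    for t in range(6):  # e.g. the 6 real tokens of the example prompt
        q, k, v = jax.random.normal(jax.random.fold_in(key, t), (3, D))
        cache, out = decode_step(cache, t, q, k, v)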

I changed the batch size from 4 to 1 for two reasons:

  1. This KV-cache-enabled GPT-2 is a good fit for running in an iOS or Android app, where an input batch size of 1 lets the user type one sentence at a time.

  2. A batch size of 1 also makes it easy to type the prompt by hand when debugging:

    iree-run-module --module=/tmp/gpt2-vmvx.vmfb \
    --device="local-task" \
    --function=encode \
    --input="1x8xi32=[[15496 11 616 3290 318 13779 0 0]]" \
    --input="1xi32=[6]"

    If the batch size were 4, we would have to type a four-row matrix as the first input. (The sketch after this list shows how the example inputs above were produced.)
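For reference, the `--input` values in the command above can be reproduced with a short script like the following. This is a sketch that assumes the Hugging Face `transformers` GPT-2 tokenizer (an external dependency, not part of this repository) and that the second input is the number of non-padding tokens:

    from transformers import GPT2Tokenizer

    MAX_LEN = 8  # must match the 1x8 input shape compiled into the module
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

    ids = tokenizer.encode("Hello, my dog is cute")  # [15496, 11, 616, 3290, 318, 13779]
    padded = ids + [0] * (MAX_LEN - len(ids))        # zero-pad to the fixed length

    print(f'--input="1x{MAX_LEN}xi32=[[{" ".join(map(str, padded))}]]"')
    print(f'--input="1xi32=[{len(ids)}]"')           # count of real (non-padding) tokens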