stanford-crfm / haliax

Named Tensors for Legible Deep Learning in JAX
Apache License 2.0

Model Inference with Haliax? #10

Closed by rohan-mehta-1024 8 months ago

rohan-mehta-1024 commented 1 year ago

I am trying to replicate nanoGPT with Haliax but am running into an error when doing inference with the model. The model expects a named array as input, but unlike during training, the vector of tokens I provide will not always fill the entire context/block size, and so will not have the Block axis. In fact, it should probably be an unnamed array for this reason, because there is no meaningful axis to represent. But I can't pass a plain JAX array to my Haliax model. So how should inference be done? I tried checking the Levanter repo but was unable to figure it out. Apologies if this is an elementary question.

dlwh commented 1 year ago

Hi @rohan-mehta-1024 ,

That sounds like a great project! Just so we're clear, you're using Block as the name of the sequence-length Axis? I still think it's meaningful to give it a name: the axis has a meaning. Also note that not all arrays with an axis called "Block" have to have the same length for that axis. That said, because of the way JAX works, the model has to be recompiled every time you pass in a new length, so I'd recommend padding to a fixed set of lengths.

I've personally not implemented inference in Haliax/Levanter yet, but I don't think there should be too many surprises.
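To make the "padding to some set of lengths" suggestion concrete, here is a minimal sketch in plain Python (no JAX required): round each sequence length up to the nearest of a few fixed bucket sizes, so the compiler only ever sees a handful of distinct shapes. `BUCKETS`, `pick_bucket`, and `pad_to_bucket` are hypothetical names for illustration, not part of Haliax.

```python
# Round sequence lengths up to a small fixed set of buckets so JAX
# recompiles at most once per bucket, not once per length.
BUCKETS = (32, 64, 128, 256)  # e.g. powers of two up to Block.size

def pick_bucket(length: int) -> int:
    """Return the smallest bucket length that can hold `length` tokens."""
    for b in BUCKETS:
        if length <= b:
            return b
    raise ValueError(f"sequence of length {length} exceeds the largest bucket")

def pad_to_bucket(tokens: list[int], pad_id: int = 0) -> list[int]:
    """Left-pad `tokens` with `pad_id` up to its bucket length."""
    bucket = pick_bucket(len(tokens))
    return [pad_id] * (bucket - len(tokens)) + tokens
```

The same padded arrays would then be wrapped with `hax.named(..., (Block,))` as usual; the point is only that the set of shapes reaching the model stays small.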

rohan-mehta-1024 commented 1 year ago

Thank you for this clarification! Yes, I am using Block as the axis to represent sequence length / context size. I ended up doing inference like this and was just wondering if this is "idiomatic" Haliax:

def generate(self, seq: jax.Array, max_new_tokens: int, key: jr.PRNGKey):
    full_seq = seq
    for _ in range(max_new_tokens):
        # crop to the last Block.size tokens (padding if too short) and name the axis
        block = hax.named(self._pad(full_seq[-Block.size:]), (Block,))
        logits = self(block)
        last_token_logits = logits[Block, -1, Embed, :]  # logits for the final position
        key, subkey = jr.split(key)
        next_token = hax.random.categorical(key=subkey, logits=last_token_logits, axis=Embed)
        next_token = jnp.expand_dims(next_token.array, axis=0)  # reshape scalar token to [token]
        full_seq = jnp.concatenate([full_seq, next_token])
    return full_seq

So basically I take in a JAX array of some arbitrary length, take the last Block.size elements of it to fill the context size (padding if there are not enough), predict the next token, append it to the end of that JAX array, and repeat.
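The `_pad` helper in the snippet above isn't shown; here is one plausible shape for the crop-and-pad step, sketched in NumPy. `BLOCK_SIZE`, `PAD_ID`, and `crop_and_pad` are assumptions standing in for `Block.size`, the model's padding token id, and the helper itself.

```python
import numpy as np

BLOCK_SIZE = 8  # stands in for Block.size
PAD_ID = 0      # hypothetical padding token id

def crop_and_pad(seq: np.ndarray) -> np.ndarray:
    """Take the last BLOCK_SIZE tokens of `seq`, left-padding with PAD_ID
    when the sequence is shorter than the context window."""
    window = seq[-BLOCK_SIZE:]
    pad_len = BLOCK_SIZE - window.shape[0]
    return np.concatenate([np.full(pad_len, PAD_ID, dtype=window.dtype), window])
```

The output always has length `BLOCK_SIZE`, so wrapping it as `hax.named(..., (Block,))` gives the model a fixed shape and avoids recompiling per length.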

dlwh commented 1 year ago

I think this is good/reasonable. It would be good to add pad and similar helpers to Haliax (might be a decent PR...)

One could probably Haliax-ify the seq/full_seq handling, but I think it's ok.

I think "Embed" is a weird choice of axis name here. Ordinarily I'd expect Vocab or TokenId or something.

rohan-mehta-1024 commented 1 year ago

Ok, thanks for that confirmation. It's a little confusing, but Embed and Vocab have the same size here because I'm testing things out with a bigram model before moving to the full transformer architecture. I'm kinda new to this, but I'd like to support this library, as I think it's really cool, so I'll try to submit a PR for pad/unpad functions.

dlwh commented 8 months ago

Closing this, just b/c there's no definition of done.