stanford-crfm / haliax

Named Tensors for Legible Deep Learning in JAX
Apache License 2.0

Tutorial Suggestion #8

Closed pharringtonp19 closed 1 year ago

pharringtonp19 commented 1 year ago

Really enjoying the tutorial. I have one small suggestion: for attention_scores_draft, it might be helpful to include the tensors needed to try out the function, for example:

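import jax.numpy as jnp
import haliax as hax
from jax.random import PRNGKey

# Named axes; KPos aliases Pos so query and key positions stay distinct
Pos = hax.Axis("position", 1024)
Key = hax.Axis("key", 64)
KPos = Pos.alias("key_position")
query = hax.random.uniform(PRNGKey(0), (Pos, Key))
key = hax.random.uniform(PRNGKey(1), (Key, KPos))

def attention_scores_draft(Key, query, key):
    # Contract over Key, scale by sqrt(Key.size), then softmax over key positions
    scores = hax.dot(Key, query, key) / jnp.sqrt(Key.size)
    scores = hax.nn.softmax(scores, KPos)
    return scores

attention_scores_draft(Key, query, key)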
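
For comparison, here is a minimal sketch of the same computation in plain jax.numpy with positional axes, which can help when checking shapes. It is not part of the tutorial: the function name attention_scores_unnamed is hypothetical, and the axis order (position, key) for query and (key, key_position) for key is assumed from the snippet above.

```python
import jax
import jax.numpy as jnp

def attention_scores_unnamed(query, key):
    # query: (position, key_dim); key: (key_dim, key_position)
    d_k = query.shape[-1]
    # Contract over the shared key dimension and scale by sqrt(d_k)
    scores = jnp.einsum("pd,dk->pk", query, key) / jnp.sqrt(d_k)
    # Normalize over key positions (the last axis)
    return jax.nn.softmax(scores, axis=-1)

query = jax.random.uniform(jax.random.PRNGKey(0), (1024, 64))
key = jax.random.uniform(jax.random.PRNGKey(1), (64, 1024))
scores = attention_scores_unnamed(query, key)
print(scores.shape)  # (1024, 1024)
```

Here the contraction axis and the softmax axis have to be tracked by position, which is the bookkeeping that named axes such as Key and KPos make explicit.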
dlwh commented 1 year ago

That makes sense!

dlwh commented 1 year ago

Thanks! Please let me know any more thoughts/suggestions you might have!