Closed pharringtonp19 closed 1 year ago
Really enjoying the tutorial. I have one small suggestion. For the attention_scores_draft, it might be helpful to include the necessary tensors to play with the function.
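import jax.numpy as jnp
import haliax as hax
from jax.random import PRNGKey

# Named axes: sequence positions, the key/embedding dimension, and an alias for key positions
Pos = hax.Axis("position", 1024)
Key = hax.Axis("key", 64)
KPos = Pos.alias("key_position")

# Random inputs to play with the function
query = hax.random.uniform(PRNGKey(0), (Pos, Key))
key = hax.random.uniform(PRNGKey(1), (Key, KPos))

def attention_scores_draft(Key, query, key):
    # Contract over the Key axis, then scale by sqrt of its size
    scores = hax.dot(Key, query, key) / jnp.sqrt(Key.size)
    # Normalize over key positions (KPos is taken from the enclosing scope)
    scores = hax.nn.softmax(scores, KPos)
    return scores

attention_scores_draft(Key, query, key)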
That makes sense!
Thanks! Please let me know any more thoughts/suggestions you might have!
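For anyone who wants to sanity-check the snippet above without Haliax installed, here is a rough plain-JAX sketch of the same computation. The function name `attention_scores_plain` and the positional layout (`query` as position × key, `key` as key × key_position) are assumptions made for illustration and are not part of the tutorial; axis names are replaced by einsum letters.

```python
import jax
import jax.numpy as jnp

def attention_scores_plain(query, key):
    # Hypothetical positional-axis counterpart of attention_scores_draft.
    # query: (position, key_dim); key: (key_dim, key_position)
    key_dim = query.shape[-1]
    # Contract over the shared key dimension, then scale by sqrt(key_dim)
    scores = jnp.einsum("pk,kq->pq", query, key) / jnp.sqrt(key_dim)
    # Normalize over key positions (the last axis)
    return jax.nn.softmax(scores, axis=-1)

# Same sizes as in the issue: 1024 positions, key dimension 64
query = jax.random.uniform(jax.random.PRNGKey(0), (1024, 64))
key = jax.random.uniform(jax.random.PRNGKey(1), (64, 1024))

scores = attention_scores_plain(query, key)
print(scores.shape)
```

Each row of `scores` should sum to roughly 1, since the softmax runs over key positions.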