rasbt / machine-learning-book

Code Repository for Machine Learning with PyTorch and Scikit-Learn
https://sebastianraschka.com/books/#machine-learning-with-pytorch-and-scikit-learn
MIT License

Confusing matrix operations in Chpt. 16 #170

Open xiongtx opened 3 months ago

xiongtx commented 3 months ago

I find the matrix operations in Chpt. 16 confusing. For example, instead of:

keys = U_key.matmul(embedded_sentence.T).T
values = U_value.matmul(embedded_sentence.T).T
omega_23 = query_2.dot(keys[2])       # single unnormalized score for query 2, key 3
omega_2 = query_2.matmul(keys.T)      # all unnormalized scores for query 2
attention_weights_2 = F.softmax(omega_2 / d**0.5, dim=0)

It's clearer to do:

queries = embedded_sentence @ U_query
keys = embedded_sentence @ U_key
values = embedded_sentence @ U_value
omega = queries @ keys.T
attention_weights = F.softmax(omega / d**0.5, dim=-1)  # one row per query, softmax over keys
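
For reference, here is a self-contained version of that single-head computation. The sentence length, embedding size, and random initialization below are just placeholders to make it runnable, not the book's actual setup:

import torch
import torch.nn.functional as F

torch.manual_seed(123)                 # placeholder seed
seq_len, d = 6, 16                     # assumed sentence length and embedding dim
embedded_sentence = torch.rand(seq_len, d)

# one projection matrix per role; square (d x d) for simplicity
U_query = torch.rand(d, d)
U_key = torch.rand(d, d)
U_value = torch.rand(d, d)

queries = embedded_sentence @ U_query  # (seq_len, d)
keys = embedded_sentence @ U_key       # (seq_len, d)
values = embedded_sentence @ U_value   # (seq_len, d)

omega = queries @ keys.T                               # (seq_len, seq_len) scores
attention_weights = F.softmax(omega / d**0.5, dim=-1)  # each row sums to 1
z = attention_weights @ values                         # (seq_len, d) context vectors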

With multi-head attention, instead of:

stacked_inputs = embedded_sentence.T.repeat(8, 1, 1)
multihead_keys = torch.bmm(multihead_U_key, stacked_inputs)
multihead_keys = multihead_keys.permute(0, 2, 1)
# Eventually giving up...
multihead_z_2 = torch.rand(8, 16)

we can just do:

multihead_queries = embedded_sentence @ multihead_U_query
multihead_keys = embedded_sentence @ multihead_U_key
multihead_values = embedded_sentence @ multihead_U_value
multihead_weights = F.softmax(multihead_queries @ multihead_keys.transpose(1, 2) / d**0.5, dim=-1)
multihead_z = multihead_weights @ multihead_values

which makes it clear that the multihead case is analogous to the single-head case.
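
Here is a matching self-contained sketch of the multi-head version, using the chapter's 8 heads; the per-head projection shapes and the other dimensions are my assumptions, chosen to line up with the single-head sketch above:

import torch
import torch.nn.functional as F

torch.manual_seed(123)
seq_len, d, num_heads = 6, 16, 8       # assumed dimensions, 8 heads as in the chapter
embedded_sentence = torch.rand(seq_len, d)

# one (d x d) projection per head, stacked along dim 0
multihead_U_query = torch.rand(num_heads, d, d)
multihead_U_key = torch.rand(num_heads, d, d)
multihead_U_value = torch.rand(num_heads, d, d)

# (seq_len, d) broadcasts against (num_heads, d, d) -> (num_heads, seq_len, d)
multihead_queries = embedded_sentence @ multihead_U_query
multihead_keys = embedded_sentence @ multihead_U_key
multihead_values = embedded_sentence @ multihead_U_value

scores = multihead_queries @ multihead_keys.transpose(1, 2)  # (num_heads, seq_len, seq_len)
multihead_weights = F.softmax(scores / d**0.5, dim=-1)
multihead_z = multihead_weights @ multihead_values           # (num_heads, seq_len, d)

The only change from the single-head case is the extra leading head dimension, which broadcasting and batched matmul handle automatically.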

rasbt commented 2 months ago

Thanks for the comment, and I 100% agree. Not sure why I made it unnecessarily complicated there. In my other book (Build an LLM from Scratch), I use a more legible version, similar to what you suggest: https://github.com/rasbt/LLMs-from-scratch/blob/main/ch03/01_main-chapter-code/ch03.ipynb