Closed labdmitriy closed 6 months ago
Also I have a question: could you please explain why we need to call contiguous() in the following line in the MultiHeadAttention class:
context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)
mask_unsqueezed = mask_bool.unsqueeze(0).unsqueeze(0)
Ah yes, this was unnecessary, so I updated it to just mask_bool.unsqueeze(0) a while back. I will look into whether I can remove it altogether like you suggest. Thanks!
Also I have a question: could you please explain why we need to call contiguous() in the following line in the MultiHeadAttention class:
context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)
Good question. This is because of the way the memory is organized in this tensor; the .view() would raise an error. What you could do instead is
context_vec = context_vec.reshape(b, num_tokens, self.d_out)
This works because (quoting from the documentation):
When possible, the returned tensor will be a view of input. Otherwise, it will be a copy. Contiguous inputs and inputs with compatible strides can be reshaped without copying, but you should not depend on the copying vs. viewing behavior.
However, I haven't used .reshape elsewhere in this book, so I wanted to stick with .view for consistency.
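To make the difference between the three options concrete, here is a minimal sketch on a small toy tensor (not the book's code; the shapes and variable names are chosen only for illustration). It assumes a transposed tensor whose last two dimensions cannot be merged in place, so .view() fails while .contiguous().view() and .reshape() both work by copying:

```python
import torch

# Toy tensor, contiguous in memory: shape (2, 3, 4), strides (12, 4, 1)
x = torch.arange(24).view(2, 3, 4)

# Transposing only swaps the strides; no data is moved
t = x.transpose(1, 2)  # shape (2, 4, 3), strides (12, 1, 4)
print(t.is_contiguous())

# .view() must reuse the existing memory, which is not possible here
try:
    t.view(2, 12)
    view_ok = True
except RuntimeError:
    view_ok = False

# Both alternatives copy the data into a fresh contiguous layout
via_contiguous = t.contiguous().view(2, 12)
via_reshape = t.reshape(2, 12)

# The reshape result no longer shares storage with the original tensor
shares_memory = via_reshape.data_ptr() == x.data_ptr()
print(view_ok, torch.equal(via_contiguous, via_reshape), shares_memory)
```

In this case .reshape() and .contiguous().view() give the same result; the difference is only that .reshape() decides on its own whether a copy is needed.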
Nice, it turns out you were right, the .unsqueeze(0) was indeed redundant. Love it, it makes the code even simpler and more readable!
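For reference, a small sketch of why the unsqueeze calls are redundant: masked_fill_ broadcasts a 2D mask over the leading batch and head dimensions. The tensor sizes below are arbitrary and only for illustration:

```python
import torch

torch.manual_seed(123)
b, num_heads, num_tokens = 2, 3, 4

# Stand-in for the attention scores: (b, num_heads, num_tokens, num_tokens)
attn_scores = torch.randn(b, num_heads, num_tokens, num_tokens)

# Causal mask of shape (num_tokens, num_tokens); True above the diagonal
mask_bool = torch.triu(torch.ones(num_tokens, num_tokens), diagonal=1).bool()

# The 2D mask is broadcast across the batch and head dimensions
plain = attn_scores.clone().masked_fill_(mask_bool, -torch.inf)

# Explicitly adding the two leading dimensions gives the same result
unsqueezed = attn_scores.clone().masked_fill_(
    mask_bool.unsqueeze(0).unsqueeze(0), -torch.inf
)

print(torch.equal(plain, unsqueezed))
```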
Sebastian, thanks a lot for your response,
Good question. This is because of the way the memory is organized in this tensor; the .view() would raise an error
Yes, I asked this question because when I deleted .contiguous():
context_vec = context_vec.view(b, num_tokens, self.d_out)
I didn't get any errors and got the same results.
The only other reason to convert to a contiguous tensor that I found here was the following:
This create issues with parallel computations.
But I didn't find a more detailed explanation. Could you please share your thoughts about it?
Thank you.
@rasbt
If possible, could you please comment on the last question? I checked again today and found that I still don't get any error even without .contiguous().
Thank you.
Sorry, I must have missed your previous comment. It's interesting that you don't get any issues. Testing ch03 now on macOS with PyTorch 2.4, I see that it's fine without the .contiguous call now. However, in the MultiHeadAttentionCombinedQKV class in https://github.com/rasbt/LLMs-from-scratch/blob/main/ch03/02_bonus_efficient-multihead-attention/mha-implementations.ipynb it still seems required:
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[5], line 69
57 return context_vec
60 mha_combined_qkv = MultiHeadAttentionCombinedQKV(
61 d_in=embed_dim,
62 d_out=embed_dim,
(...)
66 qkv_bias=False
67 ).to(device)
---> 69 out = mha_combined_qkv(embeddings)
70 print(out.shape)
File ~/miniforge3/envs/book/lib/python3.11/site-packages/torch/nn/modules/module.py:1553, in Module._wrapped_call_impl(self, *args, **kwargs)
1551 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1552 else:
-> 1553 return self._call_impl(*args, **kwargs)
File ~/miniforge3/envs/book/lib/python3.11/site-packages/torch/nn/modules/module.py:1562, in Module._call_impl(self, *args, **kwargs)
1557 # If we don't have any hooks, we want to skip the rest of the logic in
1558 # this function, and just call forward.
1559 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1560 or _global_backward_pre_hooks or _global_backward_hooks
1561 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1562 return forward_call(*args, **kwargs)
1564 try:
1565 result = None
Cell In[5], line 53, in MultiHeadAttentionCombinedQKV.forward(self, x)
50 context_vec = context_vec.transpose(1, 2)
52 # (b, num_tokens, num_heads, head_dim) --> (b, num_tokens, embed_dim)
---> 53 context_vec = context_vec.view(batch_size, num_tokens, embed_dim)
55 context_vec = self.proj(context_vec)
57 return context_vec
RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.
Thanks a lot for your answer!
I found that for d_out > 2 where d_out % num_heads == 0, there is also an error if contiguous() is not used (so probably there is no error only for d_out == 2 without contiguous()):
RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.
I am still looking for an intuitive explanation for it; if I find one, I will post it here. Thank you.
Likewise, I am curious, too.
If you use the following two functions:
import matplotlib.pyplot as plt

def print_tensor_info(tensor, name):
    print(f"{name}:")
    print(f" Shape: {tensor.shape}")
    print(f" Strides: {tensor.stride()}")
    print(f" Is Contiguous: {tensor.is_contiguous()}\n")

def visualize_memory_layout(tensor, name):
    print_tensor_info(tensor, name)
    # Flatten the tensor to 1D and convert to numpy for visualization
    flat_tensor = tensor.contiguous().view(-1).cpu().numpy()
    plt.figure(figsize=(10, 1))
    plt.title(f"Memory Layout of {name}")
    plt.imshow(flat_tensor.reshape(1, -1), aspect='auto', cmap='viridis')
    plt.yticks([])
    plt.xlabel('Memory Location')
    plt.show()
and insert them as follows:
import torch
import torch.nn as nn

print(torch.__version__)

class MultiHeadAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert (d_out % num_heads == 0), \
            "d_out must be divisible by num_heads"

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads  # Reduce the projection dim to match desired output dim

        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out)  # Linear layer to combine head outputs
        self.dropout = nn.Dropout(dropout)
        self.register_buffer(
            "mask",
            torch.triu(torch.ones(context_length, context_length),
                       diagonal=1)
        )

    def forward(self, x):
        b, num_tokens, d_in = x.shape

        keys = self.W_key(x)  # Shape: (b, num_tokens, d_out)
        queries = self.W_query(x)
        values = self.W_value(x)

        # We implicitly split the matrix by adding a `num_heads` dimension
        # Unroll last dim: (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)
        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
        values = values.view(b, num_tokens, self.num_heads, self.head_dim)
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)

        # Transpose: (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim)
        keys = keys.transpose(1, 2)
        queries = queries.transpose(1, 2)
        values = values.transpose(1, 2)

        # Compute scaled dot-product attention (aka self-attention) with a causal mask
        attn_scores = queries @ keys.transpose(2, 3)  # Dot product for each head

        # Original mask truncated to the number of tokens and converted to boolean
        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]

        # Use the mask to fill attention scores
        attn_scores.masked_fill_(mask_bool, -torch.inf)

        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # Shape: (b, num_tokens, num_heads, head_dim)
        context_vec = (attn_weights @ values).transpose(1, 2)

        visualize_memory_layout(context_vec.detach(), "before")
        visualize_memory_layout(context_vec.contiguous().detach(), "after")

        # Combine heads, where self.d_out = self.num_heads * self.head_dim
        context_vec = context_vec.view(b, num_tokens, self.d_out)
        context_vec = self.out_proj(context_vec)  # optional projection

        return context_vec

torch.manual_seed(123)

batch = torch.rand(8, 1024, 768)
batch_size, context_length, d_in = batch.shape
d_out = 4

mha = MultiHeadAttention(d_in, d_out, context_length, 0.0, num_heads=2)
mha(batch)
You can see that the issue occurs if the circled stride value below gets larger than 1024:
This is the stride corresponding to the head dimension that we want to combine via .view. Not sure why the magic number is 1024, but maybe it's something related to those values interacting during the .view call:
A smaller example to reproduce:
a = torch.randn(100, 10, 5)
a = a.permute(1, 0, 2)
print("shape", a.shape)
print("stride", a.stride())
print("contiguous", a.is_contiguous())
a.view(10*100, -1)
shape torch.Size([10, 100, 5])
stride (5, 50, 1)
contiguous False
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[28], line 7
5 print("stride", a.stride())
6 print("contiguous", a.is_contiguous())
----> 7 a.view(10*100, -1)
RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.
Thank you for the example! I will explore it too.
Oh, I think I may understand this now. In the case of d_out=2 we don't have this issue, because the shape is
torch.Size([8, 1024, 2, 1])
Note that the two dimensions we want to combine via .view are the last two dimensions. So in this case they are 2, 1, which is fine. But if you increase d_out, they become 2, 2, etc., and then you get the issue with the memory layout because the view needs to access elements across the 2 subspaces (whereas with only 1 element in that dimension, there is only 1 subspace).
Again, a smaller example to reproduce: if you change the example from my previous comment to
a = torch.randn(100, 10, 1)
a = a.permute(1, 2, 0)
print("shape", a.shape)
print("stride", a.stride())
print("contiguous", a.is_contiguous())
a.view(10, -1)
the .view() operation will work even though the memory is not contiguous.
shape torch.Size([10, 1, 100])
stride (1, 1, 10)
contiguous False
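The same effect can be checked directly on the attention shapes. The following is only a sketch: try_merge_heads is a hypothetical helper, and the tensor is a stand-in for (attn_weights @ values) with small sizes rather than the 8 x 1024 batch from above. After the transpose, the head dimension has stride num_tokens * head_dim, so merging it with head_dim only works in place when head_dim is 1:

```python
import torch

def try_merge_heads(head_dim, b=2, num_tokens=6, num_heads=2):
    """Return True if the transposed tensor can be merged with .view() alone."""
    # Stand-in for (attn_weights @ values): contiguous,
    # shape (b, num_heads, num_tokens, head_dim)
    numel = b * num_heads * num_tokens * head_dim
    x = torch.arange(numel, dtype=torch.float32).view(
        b, num_heads, num_tokens, head_dim
    )
    # (b, num_heads, num_tokens, head_dim) -> (b, num_tokens, num_heads, head_dim)
    x = x.transpose(1, 2)
    print(f"head_dim={head_dim}: shape={tuple(x.shape)}, stride={x.stride()}")
    try:
        x.view(b, num_tokens, num_heads * head_dim)
        return True
    except RuntimeError:
        return False

ok_head_dim_1 = try_merge_heads(head_dim=1)  # like d_out=2 with 2 heads
ok_head_dim_2 = try_merge_heads(head_dim=2)  # like d_out=4 with 2 heads
print(ok_head_dim_1, ok_head_dim_2)
```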
Great explanation! Thank you Sebastian!
The rule is probably more complicated than the one you described, because, for example, this transformation also works:
a = torch.randn(20, 5, 2)
a = a.permute(1, 2, 0)
print("shape", a.shape)
print("stride", a.stride())
print("contiguous", a.is_contiguous())
a.view(10, -1)
shape torch.Size([5, 2, 20])
stride (2, 1, 10)
contiguous False
I think that there is a more intuitive way to understand this rule, though I haven't found it yet: https://pytorch.org/docs/stable/generated/torch.Tensor.view.html
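For what it's worth, the condition stated in the linked documentation is that a new view dimension may span original dimensions d, ..., d+k only if stride[i] == stride[i+1] * size[i+1] for every i from d to d+k-1. A rough pure-Python sketch of that check follows. The helper name can_merge is hypothetical, and skipping dimensions of size 1 is an assumption based on the behavior observed in the examples in this thread, not something quoted from the documentation:

```python
def can_merge(sizes, strides, first, last):
    """Check whether dims first..last (inclusive) can be merged without a copy.

    Applies the contiguity-like condition from the torch.Tensor.view docs:
    stride[i] == stride[i + 1] * size[i + 1] for each adjacent pair.
    """
    # Assumption: size-1 dims are ignored, since their stride never matters
    dims = [(sizes[i], strides[i]) for i in range(first, last + 1) if sizes[i] != 1]
    return all(
        outer_stride == inner_stride * inner_size
        for (_, outer_stride), (inner_size, inner_stride) in zip(dims, dims[1:])
    )

# Shapes and strides copied from the three examples in this thread
fails = can_merge((10, 100, 5), (5, 50, 1), first=0, last=1)       # a.view(10*100, -1)
size_one = can_merge((10, 1, 100), (1, 1, 10), first=1, last=2)    # a.view(10, -1)
compatible = can_merge((5, 2, 20), (2, 1, 10), first=0, last=1)    # a.view(10, -1)
print(fails, size_one, compatible)
```

This only covers merging adjacent dimensions, which is the case discussed here; splitting a dimension is always possible in place.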
Maybe I should have used .reshape to hide that complexity 😅
Hi @rasbt,
I just want to share some additional resources on this topic that I found useful:
Thank you.
Awesome, thanks a lot! We have a weekend coming up ... that's perfect!
Hi @rasbt,
I found that the implementation of the MultiHeadAttention class has the following line:
But there is only one unsqueeze operation in the notebook:
But as I understand it, we can skip the unsqueeze operation altogether because the masked_fill_() method supports broadcasting.
Thank you.