rasbt / LLMs-from-scratch

Implementing a ChatGPT-like LLM in PyTorch from scratch, step by step
https://www.manning.com/books/build-a-large-language-model-from-scratch

Inconsistency in the unsqueeze operation between the book and the notebook, and whether it is needed (3.6.2 Implementing multi-head attention with weight splits) #61

Closed: labdmitriy closed this issue 6 months ago

labdmitriy commented 6 months ago

Hi @rasbt,

I found that the implementation of the MultiHeadAttention class in the book has the following line:

mask_unsqueezed = mask_bool.unsqueeze(0).unsqueeze(0)

However, the notebook has only one unsqueeze operation:

mask_unsqueezed = mask_bool.unsqueeze(0)

But as I understand it, we can skip the unsqueeze operation altogether because the masked_fill_() method supports broadcasting.
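A minimal sketch of the broadcasting point above, assuming the shapes from the book's MultiHeadAttention (attention scores of shape (b, num_heads, num_tokens, num_tokens) and a (num_tokens, num_tokens) boolean causal mask). The small sizes are arbitrary choices for illustration:

```python
import torch

torch.manual_seed(123)
b, num_heads, num_tokens = 2, 3, 4  # arbitrary example sizes

# Boolean causal mask: True above the diagonal (future tokens)
mask_bool = torch.triu(torch.ones(num_tokens, num_tokens), diagonal=1).bool()
attn_scores = torch.randn(b, num_heads, num_tokens, num_tokens)

# Two unsqueeze calls: mask shape (1, 1, num_tokens, num_tokens)
out_two = attn_scores.clone().masked_fill_(mask_bool.unsqueeze(0).unsqueeze(0), -torch.inf)
# One unsqueeze call: mask shape (1, num_tokens, num_tokens)
out_one = attn_scores.clone().masked_fill_(mask_bool.unsqueeze(0), -torch.inf)
# No unsqueeze: the (num_tokens, num_tokens) mask broadcasts against
# the last two dimensions of attn_scores
out_none = attn_scores.clone().masked_fill_(mask_bool, -torch.inf)

print(torch.equal(out_two, out_one), torch.equal(out_one, out_none))  # True True
```

Because the mask only needs to be broadcastable to the shape of attn_scores, all three variants fill the same positions with -inf.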

Thank you.

labdmitriy commented 6 months ago

Also, I have a question: could you please explain why we need to call contiguous() in the following line of the MultiHeadAttention class:

context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)

rasbt commented 6 months ago

mask_unsqueezed = mask_bool.unsqueeze(0).unsqueeze(0)

Ah yes, this was unnecessary, so I updated it to just mask_bool.unsqueeze(0) a while back. I will look into whether I can remove it altogether, as you suggest. Thanks!

Also I have a question - could you please explain why do we need to call contiguous() in the following line in MultiHeadAttention class:

context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)

Good question. This is because of the way the memory is organized in this tensor; without .contiguous(), the .view() call would raise an error. What you could do instead is

context_vec = context_vec.reshape(b, num_tokens, self.d_out)

This works because (quoting from the documentation):

When possible, the returned tensor will be a view of input. Otherwise, it will be a copy. Contiguous inputs and inputs with compatible strides can be reshaped without copying, but you should not depend on the copying vs. viewing behavior.
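To illustrate the difference, here is a small sketch with an arbitrary tensor standing in for attn_weights @ values (the sizes are assumptions, and head_dim > 1 is chosen on purpose). It compares .view on the transposed tensor with .contiguous().view and .reshape:

```python
import torch

b, num_heads, num_tokens, head_dim = 2, 2, 3, 4  # arbitrary sizes, head_dim > 1
# Stand-in for attn_weights @ values, shape (b, num_heads, num_tokens, head_dim)
x = torch.randn(b, num_heads, num_tokens, head_dim)
# Transposing only swaps strides; no data is moved
context_vec = x.transpose(1, 2)  # (b, num_tokens, num_heads, head_dim)
print("contiguous:", context_vec.is_contiguous())

try:
    context_vec.view(b, num_tokens, num_heads * head_dim)
    view_failed = False
except RuntimeError:
    view_failed = True  # the strides are incompatible with a view
print("view failed:", view_failed)

# Both alternatives produce the same values
via_contiguous = context_vec.contiguous().view(b, num_tokens, num_heads * head_dim)
via_reshape = context_vec.reshape(b, num_tokens, num_heads * head_dim)
print("same values:", torch.equal(via_contiguous, via_reshape))

# In this case reshape had to copy (new memory), while reshaping the
# contiguous x returns a view that shares memory with x
print("reshape copied:", via_reshape.data_ptr() != context_vec.data_ptr())
print("reshape of x is a view:", x.reshape(b, -1).data_ptr() == x.data_ptr())
```

The data_ptr() checks only show what happens in this particular case; as the quoted documentation says, code should not depend on whether reshape copies or returns a view.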

However, I haven't used .reshape elsewhere in this book, so I wanted to stick with .view for consistency.

rasbt commented 6 months ago

Nice, it turns out you were right: the .unsqueeze(0) was indeed redundant. Love it; it makes the code even simpler and more readable!

labdmitriy commented 6 months ago

Sebastian, thanks a lot for your response,

Good question. This is because the way the memory is organized in this tensor; the .view() would raise an error

Yes, I asked this question because when I removed .contiguous():

context_vec = context_vec.view(b, num_tokens, self.d_out)

I didn't get any errors and got the same results.

The only other reason to convert to a contiguous tensor that I found here was the following:

This create issues with parallel computations.

But I couldn't find a more detailed explanation. Could you please share your thoughts on it?

Thank you.

labdmitriy commented 1 month ago

@rasbt If possible, could you please comment on the last question? I checked again today and still don't get any error even without .contiguous(). Thank you.

rasbt commented 1 month ago

Sorry, I must have missed your previous comment. It's interesting that you don't get any issues. Testing ch03 now on macOS with PyTorch 2.4, I see that it's fine without the .contiguous() call. However, in the MultiHeadAttentionCombinedQKV class in https://github.com/rasbt/LLMs-from-scratch/blob/main/ch03/02_bonus_efficient-multihead-attention/mha-implementations.ipynb it still seems to be required:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[5], line 69
     57         return context_vec
     60 mha_combined_qkv = MultiHeadAttentionCombinedQKV(
     61     d_in=embed_dim,
     62     d_out=embed_dim,
   (...)
     66     qkv_bias=False
     67 ).to(device)
---> 69 out = mha_combined_qkv(embeddings)
     70 print(out.shape)

File ~/miniforge3/envs/book/lib/python3.11/site-packages/torch/nn/modules/module.py:1553, in Module._wrapped_call_impl(self, *args, **kwargs)
   1551     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1552 else:
-> 1553     return self._call_impl(*args, **kwargs)

File ~/miniforge3/envs/book/lib/python3.11/site-packages/torch/nn/modules/module.py:1562, in Module._call_impl(self, *args, **kwargs)
   1557 # If we don't have any hooks, we want to skip the rest of the logic in
   1558 # this function, and just call forward.
   1559 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1560         or _global_backward_pre_hooks or _global_backward_hooks
   1561         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1562     return forward_call(*args, **kwargs)
   1564 try:
   1565     result = None

Cell In[5], line 53, in MultiHeadAttentionCombinedQKV.forward(self, x)
     50 context_vec = context_vec.transpose(1, 2)
     52 # (b, num_tokens, num_heads, head_dim) --> (b, num_tokens, embed_dim)
---> 53 context_vec = context_vec.view(batch_size, num_tokens, embed_dim)
     55 context_vec = self.proj(context_vec)
     57 return context_vec

RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.

labdmitriy commented 1 month ago

Thanks a lot for your answer!

I found that for d_out > 2 (with d_out % num_heads == 0), there is also an error if contiguous() is not used (so it seems that omitting contiguous() only works for d_out == 2):

RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.
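One way to check this pattern without running the full model, assuming num_heads=2 as in the example above: build a stand-in for the transposed context vector for several d_out values and test whether .view succeeds. The helper view_works is hypothetical, and the batch size and token count are arbitrary:

```python
import torch

def view_works(b, num_tokens, num_heads, head_dim):
    # Hypothetical helper; stand-in for (attn_weights @ values).transpose(1, 2)
    t = torch.randn(b, num_heads, num_tokens, head_dim).transpose(1, 2)
    try:
        t.view(b, num_tokens, num_heads * head_dim)  # combine the heads
        return True
    except RuntimeError:
        return False

num_heads = 2
results = {}
for d_out in (2, 4, 6, 8):
    head_dim = d_out // num_heads
    results[d_out] = view_works(b=2, num_tokens=5, num_heads=num_heads, head_dim=head_dim)
    print(f"d_out={d_out}, head_dim={head_dim}, view works: {results[d_out]}")
```

In this sketch, only the head_dim == 1 case (d_out == 2) succeeds, which matches the observation above.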

I am still looking for an intuitive explanation; if I find one, I will post it here. Thank you.

rasbt commented 1 month ago

I'm curious about this, too.

If you use the following two functions:

import matplotlib.pyplot as plt

def print_tensor_info(tensor, name):
    print(f"{name}:")
    print(f"  Shape: {tensor.shape}")
    print(f"  Strides: {tensor.stride()}")
    print(f"  Is Contiguous: {tensor.is_contiguous()}\n")

def visualize_memory_layout(tensor, name):
    print_tensor_info(tensor, name)

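    # Flatten to 1D and convert to NumPy for plotting. Note: .contiguous().view(-1)
    # yields the elements in logical (row-major) order, so for a non-contiguous tensor
    # this plot shows logical order, not the original storage order (see the strides)
    flat_tensor = tensor.contiguous().view(-1).cpu().numpy()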

    plt.figure(figsize=(10, 1))
    plt.title(f"Memory Layout of {name}")
    plt.imshow(flat_tensor.reshape(1, -1), aspect='auto', cmap='viridis')
    plt.yticks([])
    plt.xlabel('Memory Location')
    plt.show()

and insert them as follows:

import torch
import torch.nn as nn

print(torch.__version__)

class MultiHeadAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert (d_out % num_heads == 0), \
            "d_out must be divisible by num_heads"

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads # Reduce the projection dim to match desired output dim

        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out)  # Linear layer to combine head outputs
        self.dropout = nn.Dropout(dropout)
        self.register_buffer(
            "mask",
            torch.triu(torch.ones(context_length, context_length),
                       diagonal=1)
        )

    def forward(self, x):
        b, num_tokens, d_in = x.shape

        keys = self.W_key(x) # Shape: (b, num_tokens, d_out)
        queries = self.W_query(x)
        values = self.W_value(x)

        # We implicitly split the matrix by adding a `num_heads` dimension
        # Unroll last dim: (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)
        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim) 
        values = values.view(b, num_tokens, self.num_heads, self.head_dim)
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)

        # Transpose: (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim)
        keys = keys.transpose(1, 2)
        queries = queries.transpose(1, 2)
        values = values.transpose(1, 2)

        # Compute scaled dot-product attention (aka self-attention) with a causal mask
        attn_scores = queries @ keys.transpose(2, 3)  # Dot product for each head

        # Original mask truncated to the number of tokens and converted to boolean
        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]

        # Use the mask to fill attention scores
        attn_scores.masked_fill_(mask_bool, -torch.inf)

        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # Shape: (b, num_tokens, num_heads, head_dim)
        context_vec = (attn_weights @ values).transpose(1, 2) 

        visualize_memory_layout(context_vec.detach(), "before")
        visualize_memory_layout(context_vec.contiguous().detach(), "after")

        # Combine heads, where self.d_out = self.num_heads * self.head_dim
        # Without .contiguous(), this raises a RuntimeError when head_dim > 1 (e.g., d_out=4, num_heads=2)
        context_vec = context_vec.view(b, num_tokens, self.d_out)
        context_vec = self.out_proj(context_vec) # optional projection

        return context_vec

torch.manual_seed(123)

batch = torch.rand(8, 1024, 768)

batch_size, context_length, d_in = batch.shape
d_out = 4
mha = MultiHeadAttention(d_in, d_out, context_length, 0.0, num_heads=2)

mha(batch)

You can see that the issue occurs if the circled stride value below gets larger than 1024:

[Screenshot (2024-07-25): printed tensor info, with the relevant stride value circled]

This is the stride of the head dimension that we want to combine via .view. I'm not sure why the magic number is 1024, but maybe it's related to the values interacting during the .view call:
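One possible reading of the magic number, offered as a sketch rather than a definitive explanation: after the transpose, the context vector has shape (b, num_tokens, num_heads, head_dim), and the stride of its num_heads dimension is num_tokens * head_dim. With num_tokens = 1024, as in the example, that stride exceeds 1024 exactly when head_dim > 1. The batch size of 1 is an arbitrary assumption:

```python
import torch

b, num_tokens, num_heads = 1, 1024, 2  # num_tokens as in the example; b is arbitrary

head_strides = {}
for head_dim in (1, 2, 3):
    # Stand-in for (attn_weights @ values).transpose(1, 2)
    ctx = torch.zeros(b, num_heads, num_tokens, head_dim).transpose(1, 2)
    head_strides[head_dim] = ctx.stride(2)  # stride of the num_heads dimension
    print(f"head_dim={head_dim}: shape={tuple(ctx.shape)}, strides={ctx.stride()}")
```

If this reading is right, 1024 is simply num_tokens, and the underlying condition is head_dim > 1.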

[Screenshot (2024-07-25): the values interacting during the .view call]

rasbt commented 1 month ago

A smaller example to reproduce:

a = torch.randn(100, 10, 5)
a = a.permute(1, 0, 2)

print("shape", a.shape)
print("stride", a.stride())
print("contiguous", a.is_contiguous())
a.view(10*100, -1)

shape torch.Size([10, 100, 5])
stride (5, 50, 1)
contiguous False
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[28], line 7
      5 print("stride", a.stride())
      6 print("contiguous", a.is_contiguous())
----> 7 a.view(10*100, -1)

RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.

labdmitriy commented 1 month ago

Thank you for the example! I will explore it, too.

rasbt commented 1 month ago

Oh, I think I might understand this now. In the case of d_out=2, we don't have this issue because the shape is

torch.Size([8, 1024, 2, 1])

Note that the two dimensions we want to combine via .view are the last two. In this case, they are 2 and 1, which is fine. But if you increase d_out, they become 2 and 2, etc., and then the memory layout causes the issue because .view needs to access elements across the 2 subspaces (with only 1 element in the last dimension, there is only 1 subspace).

Again, here is a smaller example to reproduce. Change the example from my previous comment to:

a = torch.randn(100, 10, 1)
a = a.permute(1, 2, 0)

print("shape", a.shape)
print("stride", a.stride())
print("contiguous", a.is_contiguous())

a.view(10, -1)

The .view() operation will work even though the memory is not contiguous:

shape torch.Size([10, 1, 100])
stride (1, 1, 10)
contiguous False

labdmitriy commented 1 month ago

Great explanation! Thank you Sebastian!

labdmitriy commented 1 month ago

The rule is probably more complex than the one you described, because, for example, this transformation also works:

a = torch.randn(20, 5, 2)
a = a.permute(1, 2, 0)

print("shape", a.shape)
print("stride", a.stride())
print("contiguous", a.is_contiguous())

a.view(10, -1)

shape torch.Size([5, 2, 20])
stride (2, 1, 10)
contiguous False

I think there is a more intuitive way to understand this rule, though I haven't found it yet: https://pytorch.org/docs/stable/generated/torch.Tensor.view.html
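The condition in the linked documentation says that original dimensions d, ..., d+k can only be flattened by .view if stride[i] == stride[i+1] * size[i+1] for each adjacent pair. The sketch below checks that condition for the three permuted examples in this thread and compares it with what .view actually does. The helpers can_merge and view_succeeds are hypothetical; can_merge skips size-1 dimensions (an assumption based on the size-1 example above working) and only covers flattening, not splitting a dimension:

```python
import torch

def can_merge(t, start, end):
    # Hypothetical helper: checks stride[i] == stride[i+1] * size[i+1] for the
    # dims start..end (inclusive), skipping size-1 dims (assumption)
    dims = [d for d in range(start, end + 1) if t.size(d) != 1]
    return all(t.stride(d) == t.stride(n) * t.size(n) for d, n in zip(dims, dims[1:]))

def view_succeeds(t, shape):
    # Hypothetical helper: True if t.view(shape) does not raise
    try:
        t.view(shape)
        return True
    except RuntimeError:
        return False

a1 = torch.randn(100, 10, 5).permute(1, 0, 2)  # shape (10, 100, 5)
a2 = torch.randn(100, 10, 1).permute(1, 2, 0)  # shape (10, 1, 100)
a3 = torch.randn(20, 5, 2).permute(1, 2, 0)    # shape (5, 2, 20)

# (tensor, first and last original dim that the target shape flattens, target shape)
cases = [(a1, 0, 1, (10 * 100, -1)), (a2, 1, 2, (10, -1)), (a3, 0, 1, (10, -1))]
results = []
for t, start, end, shape in cases:
    rule, actual = can_merge(t, start, end), view_succeeds(t, shape)
    results.append((rule, actual))
    print(tuple(t.shape), t.stride(), "rule:", rule, "view:", actual)
```

Applied to the transposed context vector of shape (b, num_tokens, num_heads, head_dim), the two merged dimensions have strides num_tokens * head_dim and 1, so the condition would require num_tokens * head_dim == head_dim. That fails whenever num_tokens > 1, except when head_dim == 1, where the size-1 dimension is skipped.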

rasbt commented 1 month ago

Maybe I should have used .reshape to hide that complexity 😅

labdmitriy commented 1 month ago

Hi @rasbt,

I just want to share some additional resources on this topic that I found useful:

Thank you.

rasbt commented 1 month ago

Awesome, thanks a lot! We have a weekend coming up ... that's perfect!