pytorch-labs / attention-gym

Helpful tools and examples for working with flex-attention
BSD 3-Clause "New" or "Revised" License
484 stars 24 forks source link

FlexAttention BlockMask creation in DataLoader #78

Closed ViktorooReps closed 1 week ago

ViktorooReps commented 1 week ago

Hi!

I am missing something, and the confusing error message does not help.

Why does this code work fine (see PaddingBlockMask in the next code section)?

# NOTE(review): not referenced by DummyModel, whose hidden size is
# 16 heads * 128 head dim = 2048; confirm whether this is used elsewhere.
hidden_dim = 1 << 10

class DummyModel(nn.Module):
    """Stack of parameter-free self-attention layers driven by flex_attention.

    Tokens are embedded once; the embedding is then split into ``num_heads``
    heads of size ``head_dim`` and passed through ``flex_attention``
    ``num_layers`` times, using the hidden state itself as query, key and
    value. The defaults reproduce the original hard-coded configuration
    (vocab 128, 16 heads, head dim 128, 12 layers).

    Args:
        vocab_size: Number of rows in the token embedding table.
        num_heads: Number of attention heads.
        head_dim: Size of each head; the hidden size is num_heads * head_dim.
        num_layers: How many times attention is applied in ``forward``.
    """

    def __init__(self, *, vocab_size=128, num_heads=16, head_dim=128, num_layers=12):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.num_layers = num_layers
        self.emb = nn.Embedding(vocab_size, num_heads * head_dim)

    def forward(self, x, mask=None):
        """Embed ``x`` and apply masked self-attention ``num_layers`` times.

        Args:
            x: Integer token ids of shape (B, S).
            mask: Optional BlockMask, forwarded to flex_attention as ``block_mask``.

        Returns:
            Tensor of shape (B, S, num_heads * head_dim).
        """
        batch, seq_len = x.shape
        hidden = self.emb(x)
        for _ in range(self.num_layers):
            # (B, S, H*D) -> (B, H, S, D): heads become their own dimension.
            hidden = hidden.view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
            hidden = flex_attention(hidden, hidden, hidden, block_mask=mask)
            # Back to (B, S, H*D); reshape copies when the transpose made it non-contiguous.
            hidden = hidden.transpose(1, 2).reshape(batch, seq_len, -1)
        return hidden

model = DummyModel().cuda()
model = torch.compile(model)

# torch.randint's upper bound is exclusive; 128 matches DummyModel's default
# embedding table size, so every valid id (0..127) can be drawn. Id 0 is also
# the padding_token below, so randomly drawn 0s are masked out as padding.
tokens = torch.randint(0, 128, size=(128, 128), device='cpu')
B, S = tokens.shape
mask = create_block_mask(
    # PaddingBlockMask copies `tokens` to CUDA, the same device the mask is
    # built on and the model runs on.
    mask_mod=and_masks(causal_block_mask, PaddingBlockMask(tokens, padding_token=0, device='cuda')),
    B=B,
    H=None,
    Q_LEN=S,
    KV_LEN=S,
    device='cuda',
)

model(tokens.to('cuda'), mask=mask.to('cuda')) is not None

And this one breaks:

def calculate_intervals(boundary_mask):
    """Assign an interval id to every position along the last dimension.

    ``cumsum(boundary_mask) + 1`` is a running id that increments at every
    boundary. The result is that running id shifted one step to the right
    along the last dimension; the first position keeps its own unshifted id.
    As a result, a boundary position carries the same id as the positions
    before it, and the new id starts right after it.

    Args:
        boundary_mask: Boolean tensor, True at boundary positions, typically
            (batch, seq_len). 1-D input and zero-length sequences are handled.

    Returns:
        Integer tensor with the same shape as ``boundary_mask``.
    """
    interval_idx = torch.cumsum(boundary_mask, dim=-1) + 1
    # Shift right by one along the last dim only. torch.roll without `dims`
    # rolls the flattened tensor across rows; slicing also keeps S == 0 valid.
    return torch.cat((interval_idx[..., :1], interval_idx[..., :-1]), dim=-1)

class IntervalBlockMask:
    """mask_mod restricting attention to positions that share an interval.

    ``boundary_mask`` is the negation of ``interval_mask``, and
    ``interval_idx`` holds the per-position interval ids that
    ``calculate_intervals`` derives from it. A (q_idx, kv_idx) pair is
    allowed when any of these holds:

    * the key position is a boundary,
    * ``q_idx - kv_idx < max_distance`` (with the default 0 this only admits
      keys after the query),
    * query and key carry the same interval id.

    NOTE(review): ``~`` is a logical NOT only for bool tensors; on an integer
    0/1 mask it is a bitwise complement (-1/-2). Pass a bool ``interval_mask``.
    Tensors are indexed as ``[b, position]``, so a (batch, seq_len) mask is
    assumed — TODO confirm with callers.
    """

    def __init__(self, interval_mask, *, device: str, max_distance: int = 0):
        self.max_distance = max_distance
        # Captured tensors live on `device`; __call__ indexes them directly.
        self.boundary_mask = ~interval_mask.to(device)
        self.interval_idx = calculate_intervals(self.boundary_mask)

    def __call__(self, b, h, q_idx, kv_idx):
        """Return whether query ``q_idx`` may attend to key ``kv_idx`` in batch ``b``."""
        kv_is_boundary = self.boundary_mask[b, kv_idx]
        within_distance = q_idx - kv_idx < self.max_distance
        same_interval = torch.eq(self.interval_idx[b, q_idx], self.interval_idx[b, kv_idx])
        return kv_is_boundary | within_distance | same_interval

class PaddingBlockMask:
    """mask_mod that hides every query/key pair touching a padding token.

    Args:
        tokens: Token ids indexed as ``tokens[b, position]``. They are copied
            to ``device`` so the mask_mod can index them there.
        device: Target device for the captured ``tokens`` tensor.
        padding_token: Id treated as padding.
    """

    def __init__(self, tokens, *, device: str, padding_token: int):
        self.padding_token = padding_token
        self.tokens = tokens.to(device)

    def __call__(self, b, h, q_idx, kv_idx):
        """Allow the pair only if neither position holds ``padding_token``."""
        pad = self.padding_token
        q_is_real = self.tokens[b, q_idx] != pad
        kv_is_real = self.tokens[b, kv_idx] != pad
        return q_is_real & kv_is_real

def causal_block_mask(b, h, q_idx, kv_idx):
    """mask_mod for causal attention: a query sees itself and earlier keys."""
    return kv_idx <= q_idx

class Collator:
    """DataLoader ``collate_fn`` that pads a batch and builds its BlockMask.

    Each batch item is read as a mapping with a ``'tokens'`` tensor. When the
    matching flag is set, it also needs ``'segment_mask'`` / ``'specials_mask'``
    tensors. Items are right-padded with ``pad_sequence(batch_first=True)``.
    The mask_mod is a causal mask AND-ed with a padding mask and, optionally,
    interval masks built from the segment / specials masks.

    Args:
        padding_token: Id used to pad ``'tokens'``; the padding mask also
            treats it as padding.
        mask_words: Add an interval mask built from ``'specials_mask'``.
        mask_segments: Add an interval mask built from ``'segment_mask'``.
        device: Device for the padded tokens, the tensors captured by the
            mask_mods, and ``create_block_mask``.

    NOTE(review): as far as I know, ``BlockMask.to()`` moves only the
    BlockMask's own tensors, not tensors captured inside ``mask_mod``. So
    ``device`` should be the device the model runs on. Creating CUDA tensors
    inside forked DataLoader workers (``num_workers > 0``) is generally not
    supported by PyTorch — confirm before combining a CUDA ``device`` with
    worker processes.
    """

    def __init__(
            self,
            padding_token: int,
            mask_words: bool = False,
            mask_segments: bool = False,
            device: str = 'cpu'
    ):
        self.padding_token = padding_token
        self.mask_words = mask_words
        self.mask_segments = mask_segments
        self.device = device

    def __call__(self, batch):
        """Collate ``batch`` into padded tokens plus a combined BlockMask.

        Returns:
            Dict with ``'tokens'`` (padded, shape (batch_size, seq_length)) and
            ``'block_mask'`` (the BlockMask for that shape).
        """
        masks = [causal_block_mask]

        tokens = tuple(item['tokens'].to(self.device) for item in batch)
        tokens_padded = pad_sequence(tokens, batch_first=True, padding_value=self.padding_token)
        masks.append(PaddingBlockMask(tokens_padded, padding_token=self.padding_token, device=self.device))

        if self.mask_segments:
            segment_mask = tuple(item['segment_mask'] for item in batch)
            segment_mask_padded = pad_sequence(segment_mask, batch_first=True, padding_value=0)
            # IntervalBlockMask negates its input with `~`. That is a logical NOT
            # only on bool tensors; on 0/1 integers it yields -1/-2 (all truthy).
            masks.append(IntervalBlockMask(segment_mask_padded.bool(), device=self.device))

        if self.mask_words:
            specials_mask = tuple(item['specials_mask'] for item in batch)
            specials_mask_padded = pad_sequence(specials_mask, batch_first=True, padding_value=0)
            # Coerce to bool first so `~` is a logical NOT, not a bitwise complement.
            masks.append(IntervalBlockMask(~specials_mask_padded.bool(), device=self.device))

        batch_size, seq_length = tokens_padded.shape

        return {
            'tokens': tokens_padded,
            'block_mask': create_block_mask(
                mask_mod=and_masks(*masks),
                B=batch_size,
                H=None,
                Q_LEN=seq_length,
                KV_LEN=seq_length,
                device=self.device
            )
        }

base = 128
mult_min = 2
mult_max = 100
n_trials = 10

import itertools
import os
# NOTE(review): CUDA reads this variable when it initializes, so it likely has
# no effect if CUDA was already used earlier in the session (e.g. by
# `DummyModel().cuda()`). The timing below synchronizes explicitly instead.
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

data = []
for length_mult in tqdm(range(mult_min, mult_max)):
    length = base * length_mult
    dataset = Dataset('data/tinyshakespeare/train', start_token=2, target_length=length)

    # Different configurations
    for name, mask_words, mask_segments in [
        ('normal', False, False),
        #('mask_words', True, False),
        #('mask_words_segments', True, True)
    ]:
        collator = Collator(padding_token=2, mask_words=mask_words, mask_segments=mask_segments, device='cuda')
        loader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0, collate_fn=collator)

        # islice stops after n_trials batches. zip(loader, range(n)) would
        # fetch and collate one extra batch before noticing the range is
        # exhausted.
        for batch in itertools.islice(loader, n_trials):
            tokens = batch['tokens'].to('cuda')
            block_mask = batch['block_mask'].to('cuda')

            print(tokens, block_mask)

            torch.cuda.empty_cache()
            # Drain queued GPU work and reset the peak counter, so this trial
            # measures only its own forward pass.
            torch.cuda.synchronize()
            torch.cuda.reset_peak_memory_stats()

            # NOTE(review): if `model` is torch.compile'd, the first trial at
            # each new length likely includes recompilation time.
            start_time = time.perf_counter()
            start_memory = torch.cuda.memory_allocated()

            model(tokens, block_mask)
            # CUDA kernels run asynchronously; wait for them before stopping the clock.
            torch.cuda.synchronize()

            time_taken = time.perf_counter() - start_time
            # The output is discarded right after the call, so the allocation
            # afterwards is back near start_memory. The peak reflects what the
            # forward pass actually needed.
            end_memory = torch.cuda.max_memory_allocated()
            memory_used = end_memory - start_memory

            data.append({
                'model': name,
                'sequence_length': length,
                'time_taken': time_taken,
                'memory_used': memory_used
            })

The error I am getting is:

unknown:0: unknown: block: [2,0,0], thread: [0,0,0] Assertion `index out of bounds: 0 <= tmp4 < 128` failed.
unknown:0: unknown: block: [2,0,0], thread: [1,0,0] Assertion `index out of bounds: 0 <= tmp4 < 128` failed.
unknown:0: unknown: block: [2,0,0], thread: [2,0,0] Assertion `index out of bounds: 0 <= tmp4 < 128` failed.
unknown:0: unknown: block: [2,0,0], thread: [3,0,0] Assertion `index out of bounds: 0 <= tmp4 < 128` failed.
unknown:0: unknown: block: [2,0,0], thread: [4,0,0] Assertion `index out of bounds: 0 <= tmp4 < 128` failed.
unknown:0: unknown: block: [2,0,0], thread: [5,0,0] Assertion `index out of bounds: 0 <= tmp4 < 128` failed.
unknown:0: unknown: block: [2,0,0], thread: [6,0,0] Assertion `index out of bounds: 0 <= tmp4 < 128` failed.
unknown:0: unknown: block: [2,0,0], thread: [7,0,0] Assertion `index out of bounds: 0 <= tmp4 < 128` failed.
unknown:0: unknown: block: [2,0,0], thread: [8,0,0] Assertion `index out of bounds: 0 <= tmp4 < 128` failed.
unknown:0: unknown: block: [2,0,0], thread: [9,0,0] Assertion `index out of bounds: 0 <= tmp4 < 128` failed.
unknown:0: unknown: block: [2,0,0], thread: [10,0,0] Assertion `index out of bounds: 0 <= tmp4 < 128` failed.
...

And so on; at the end it says:

RuntimeError: CUDA error: device-side assert triggered

Am I using BlockMask wrong? Is it not supposed to work with DataLoader? Ideally, I would like to create masks in advance with num_workers > 0.

The torch version is 2.6.0.dev20241112+cu121, I am working in Jupyter Notebook.

Thank you so much for any help!

drisspg commented 1 week ago

Hey, could you correct the imports so that I can run the script? I tried to add some of the missing ones, but I didn't get all of them.

ViktorooReps commented 1 week ago

@drisspg Thank you for the reply! I have run into another issue, so I'll get back to this a bit later (I'll reopen this issue then).

Could you have a look at the other issue I am having: #79?

ViktorooReps commented 1 week ago

By the way, I think part of my problem is that the masks are implemented as classes (FlexAttention becomes very slow).

I'll investigate later.