lucidrains / x-transformers

A concise but complete full-attention transformer with a set of promising experimental features from various papers

[Feature request] Efficient stochastic path without unused layers #294

Open Aceticia opened 1 week ago

Aceticia commented 1 week ago

DINO v2 finds that high values of stochastic depth are very helpful for larger models in terms of performance, and they also give an efficient implementation here that only operates on the un-dropped samples of a batch, which is very simple:

from typing import Callable

import torch
from torch import Tensor


def drop_add_residual_stochastic_depth(
    x: Tensor,
    residual_func: Callable[[Tensor], Tensor],
    sample_drop_ratio: float = 0.0,
) -> Tensor:
    # 1) extract subset using permutation
    b, n, d = x.shape
    sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
    brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
    x_subset = x[brange]

    # 2) apply residual_func to get residual
    residual = residual_func(x_subset)

    x_flat = x.flatten(1)
    residual = residual.flatten(1)

    residual_scale_factor = b / sample_subset_size

    # 3) add the residual
    x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
    return x_plus_residual.view_as(x)
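
For reference, here is a minimal sketch of how this could wrap the residual branch of a block (the module names and drop rate below are illustrative assumptions, not x-transformers internals):

import torch
from torch import nn

# hypothetical block: wraps a pre-norm attention residual with the function above
class BlockWithStochasticDepth(nn.Module):
    def __init__(self, dim, heads = 8, sample_drop_ratio = 0.4):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, heads, batch_first = True)
        self.sample_drop_ratio = sample_drop_ratio

    def residual_func(self, x):
        x = self.norm(x)
        out, _ = self.attn(x, x, x)
        return out

    def forward(self, x):
        if self.training and self.sample_drop_ratio > 0.:
            # only the kept subset of the batch runs through attention
            return drop_add_residual_stochastic_depth(
                x, self.residual_func, sample_drop_ratio = self.sample_drop_ratio
            )
        # at eval time, run the full batch and add the residual as usual
        return x + self.residual_func(x)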

In practice, with stochastic depth up to 0.4, the memory usage almost halves.

This repo already provides a form of stochastic depth where entire layers are dropped. That achieves a similar effect to the DINO v2 implementation, in that dropped computation isn't wasted. However, because whole layers are skipped, we are forced to use find_unused_parameters=True when training with DDP, which adds further overhead... besides, dropping an entire layer across the whole batch feels kinda weird and might introduce biases.
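
To make the DDP point concrete, here is a simplified sketch of layer-level drop (for illustration only, not the actual x-transformers code): when a layer is skipped for a step, none of its parameters enter the autograd graph, so DDP cannot reduce their gradients without find_unused_parameters=True.

import torch
from torch import nn

# simplified sketch of layer-level stochastic depth, for illustration only
class LayerDropStack(nn.Module):
    def __init__(self, layers, layer_dropout = 0.1):
        super().__init__()
        self.layers = nn.ModuleList(layers)
        self.layer_dropout = layer_dropout

    def forward(self, x):
        for layer in self.layers:
            if self.training and torch.rand(()).item() < self.layer_dropout:
                # whole layer skipped: none of its parameters are used this step,
                # which is what forces find_unused_parameters=True under DDP
                continue
            x = x + layer(x)
        return x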

I can contribute something and integrate this into the attention and MLP layers. What do you think? Are there any other reasons you keep the entire-layer drop (apart from the potential overhead when the drop rate is low)?

lucidrains commented 1 week ago

@Aceticia hello again Xujin/Chris

stochastic depth is popular in some circles for sure

what do you think about just forcing the parameters to be used by sending in a single dummy token, multiplying the output by 0 and summing it to the stream? that should fix the ddp issue?
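
a rough sketch of what i have in mind (names here are just illustrative, not the final implementation):

import torch

# sketch: keep a dropped layer in the graph by running it on a single dummy token,
# zeroing the result, and adding it to the stream
def forward_with_layer_drop(layers, x, layer_dropout = 0.1, training = True):
    b, n, d = x.shape
    for layer in layers:
        if training and torch.rand(()).item() < layer_dropout:
            dummy = layer(x.new_zeros(b, 1, d))
            # multiplied by 0, so it contributes nothing to the output,
            # but every parameter of the layer still receives a (zero) gradient
            x = x + 0. * dummy.sum(dim = 1, keepdim = True)
            continue
        x = x + layer(x)
    return x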

Aceticia commented 1 week ago

Hello again! I go by Chris :D

Sounds like a good solution, I can't really think of any side effects.

lucidrains commented 1 week ago

@Aceticia ok Chris i'll add it later this evening and you can let me know if that unused parameters issue persists

lucidrains commented 1 week ago

@Aceticia did you see anything interesting when splitting dimensions for alibi across heads?

Aceticia commented 1 week ago

> @Aceticia did you see anything interesting when splitting dimensions for alibi across heads?

I tried it out; I didn't have time for a complete run, but sadly I don't see much difference from just using alibi in time. We made the compromise of using consistent time ordering across samples, with rotary pos emb in time and a learned positional embedding across space, and it's the best we have so far.

Can't spend forever on this - sorry to have taken up some of your time. Good knowledge to have, though.

lucidrains commented 1 week ago

@Aceticia no problem! just your sharing this makes it worth it

thanks!