idiap / fast-transformers

PyTorch library for fast transformer implementations

Model diff between commits #79

Closed danieltudosiu closed 3 years ago

danieltudosiu commented 3 years ago

Hi there,

We are using your package as a dependency of performer-PyTorch, and we have observed a model diff between version 0.3 and commit 68d26c0e971dc2de49bca079c373ed27b3e383af.

The model size increases by 26%, the compute time of an epoch increases by 350%, and our old checkpoints are no longer compatible.

My question is: are these effects intended, or is there a problem?
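
To make the comparison concrete, here is a rough sketch of how such a diff could be inspected, assuming the checkpoints are plain state dicts saved under each install (the file names are hypothetical):

import torch

# Hypothetical checkpoint paths: one saved under v0.3.0, one under the newer commit.
old = torch.load("checkpoint_v0.3.0.pt", map_location="cpu")
new = torch.load("checkpoint_68d26c0.pt", map_location="cpu")

# Keys present in only one of the two state dicts.
old_keys, new_keys = set(old), set(new)
print("only in old:", sorted(old_keys - new_keys))
print("only in new:", sorted(new_keys - old_keys))

# Shape changes in the parameters both checkpoints share.
for k in sorted(old_keys & new_keys):
    if old[k].shape != new[k].shape:
        print(k, tuple(old[k].shape), "->", tuple(new[k].shape))

# Total parameter counts, to quantify the size difference.
print("old params:", sum(v.numel() for v in old.values()))
print("new params:", sum(v.numel() for v in new.values()))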

Thanks!

angeloskath commented 3 years ago

Hi Daniel,

Could you share some code that replicates this? As a sanity check, I ran the following test in two brand-new environments, installing from the v0.3.0 tag and from master.

from fast_transformers.builders import TransformerEncoderBuilder
from fast_transformers.feature_maps import Favor
from fast_transformers.masking import TriangularCausalMask
import torch

if __name__ == "__main__":
    tr = TransformerEncoderBuilder.from_kwargs(
        n_layers=4,
        n_heads=4,
        query_dimensions=32,
        feature_map=Favor.factory(256),
        attention_type="causal-linear"
    ).get().cuda()
    x = torch.randn(10, 1000, 4*64).cuda()
    mask = TriangularCausalMask(1000, device="cuda")
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    # warmup
    tr(x, attn_mask=mask).sum().backward()

    # time 10 forward/backward passes between CUDA events
    start.record()
    for i in range(10):
        tr(x, attn_mask=mask).sum().backward()
    end.record()
    torch.cuda.synchronize()
    print("Elapsed time:", start.elapsed_time(end), "ms")
    print("# params:", sum(v.numel() for v in tr.state_dict().values()))
    print("Mem allocated:", torch.cuda.max_memory_allocated())

The results are as follows (on an RTX 2060S) for v0.3.0:

Elapsed time: 3654.369384765625 ms
# params: 2912768
Mem allocated: 1815452160

and for master:

Elapsed time: 1977.833251953125 ms
# params: 2912768
Mem allocated: 1815452160

So, to sum up, my first impression is that I didn't break anything. But if you could provide a test case, I'd be happy to check.
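
As a side note, the checkpoint incompatibility could be turned into a standalone test with a simple state_dict round trip between the two installs, roughly like the sketch below (the file path is a placeholder):

import torch
from fast_transformers.builders import TransformerEncoderBuilder
from fast_transformers.feature_maps import Favor

# Build the same encoder as above; run this script once under each install.
tr = TransformerEncoderBuilder.from_kwargs(
    n_layers=4,
    n_heads=4,
    query_dimensions=32,
    feature_map=Favor.factory(256),
    attention_type="causal-linear"
).get()

path = "encoder_state.pt"  # placeholder path
try:
    # Second run: try to load the weights saved by the other install.
    tr.load_state_dict(torch.load(path, map_location="cpu"), strict=True)
    print("checkpoint loads cleanly under this install")
except FileNotFoundError:
    # First run: save a fresh state dict to load under the other install.
    torch.save(tr.state_dict(), path)
    print("saved a fresh state_dict; rerun under the other install")
except RuntimeError as e:
    print("incompatibility:", e)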

Thanks, Angelos

danieltudosiu commented 3 years ago

Closing, as this was traced back to the performer-pytorch package.