lucidrains / perceiver-pytorch

Implementation of Perceiver, General Perception with Iterative Attention, in Pytorch
MIT License

just a suggestion #56

Open · seyeeet opened this issue 2 years ago

seyeeet commented 2 years ago

Hi, I'd like to start by thanking you for such great work and so many great implementations. I have a small suggestion: for all your code/modules, try adding an if __name__ == "__main__": block so that someone who just wants to use one file/module can easily try it without having to go through the whole implementation. For example, I am trying to use this repo; with an if __name__ == "__main__": block I could easily run a random input through it and see how it works (see the sketch after this comment for what such a block might look like). This would greatly improve usability.

Keep up the great work :)
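For illustration, a minimal sketch of what such a block might look like at the bottom of a module file, say perceiver_io.py (the file name, shapes, and printed output are illustrative assumptions, and PerceiverIO is assumed to be defined earlier in the same file):

# hypothetical smoke test at the end of perceiver_io.py
if __name__ == "__main__":
    import torch

    model = PerceiverIO(
        dim = 32,            # dimension of the input sequence
        queries_dim = 32,    # dimension of the decoder queries
        logits_dim = 100,    # dimension of the final logits
        depth = 6            # remaining hyperparameters keep their defaults
    )

    seq = torch.randn(1, 512, 32)      # (batch, input seq len, dim)
    queries = torch.randn(128, 32)     # (decoder seq len, queries_dim)

    logits = model(seq, queries = queries)
    print(logits.shape)                # expect torch.Size([1, 128, 100])

Running the file directly would then print the output shape as a quick sanity check.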

lucidrains commented 2 years ago

@seyeeet Hi Sey! The code in my README should always work, no exceptions. But you are right, I can definitely add that code to a main executable section in each file.

lucidrains commented 2 years ago

Just run

import torch
from perceiver_pytorch import PerceiverIO

model = PerceiverIO(
    dim = 32,                    # dimension of sequence to be encoded
    queries_dim = 32,            # dimension of decoder queries
    logits_dim = 100,            # dimension of final logits
    depth = 6,                   # depth of net
    num_latents = 256,           # number of latents, or induced set points, or centroids; different papers give it different names
    latent_dim = 512,            # latent dimension
    cross_heads = 1,             # number of heads for cross attention. paper said 1
    latent_heads = 8,            # number of heads for latent self attention, 8
    cross_dim_head = 64,         # number of dimensions per cross attention head
    latent_dim_head = 64,        # number of dimensions per latent self attention head
    weight_tie_layers = False    # whether to weight tie layers (optional, as indicated in the diagram)
)

seq = torch.randn(1, 512, 32)           # (batch, input seq len, dim)
queries = torch.randn(128, 32)          # (decoder seq len, queries_dim); a 2-dim queries tensor is broadcast across the batch

logits = model(seq, queries = queries) # (1, 128, 100) - (batch, decoder seq, logits dim)

seyeeet commented 2 years ago

Hi Phil, yes, I already agree with you and am aware of it. It was just a suggestion to improve usability of future code, if you decide it is useful :) Please feel free to close it whenever you like :)

wangxiao-star commented 2 years ago

How do you handle multimodal data?
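
For context, there is no reply in the thread, but the Perceiver paper's approach to multimodal inputs is to project or pad each modality to a shared channel dimension, add modality-specific embeddings so token types stay distinguishable, and concatenate everything along the sequence axis before the first cross-attention. A rough sketch of that idea on top of PerceiverIO follows; the projection layers, shapes, and embedding scheme are illustrative assumptions, not an API this repo provides.

import torch
from torch import nn
from perceiver_pytorch import PerceiverIO

DIM = 32   # shared token dimension fed to the model

# hypothetical per-modality projections into the shared dimension
to_image_tokens = nn.Linear(3, DIM)   # RGB values per pixel -> DIM
to_audio_tokens = nn.Linear(1, DIM)   # raw audio samples -> DIM

# learned modality embeddings, added so the model can tell token types apart
image_emb = nn.Parameter(torch.randn(DIM))
audio_emb = nn.Parameter(torch.randn(DIM))

model = PerceiverIO(
    dim = DIM,
    queries_dim = DIM,
    logits_dim = 100,
    depth = 6
)

image = torch.randn(1, 32 * 32, 3)    # flattened 32x32 RGB image: (batch, pixels, channels)
audio = torch.randn(1, 1024, 1)       # raw audio: (batch, samples, 1)

# project each modality, tag it, then concatenate along the sequence axis
# (the paper also adds per-modality position encodings, omitted here for brevity)
tokens = torch.cat((
    to_image_tokens(image) + image_emb,
    to_audio_tokens(audio) + audio_emb
), dim = 1)                           # (1, 32*32 + 1024, DIM)

queries = torch.randn(128, DIM)
logits = model(tokens, queries = queries)   # (1, 128, 100)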